diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index b7f715aca0..edbb200cdc 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -5,7 +5,7 @@ Target mappers – domain → DTO translation for target-related models. """ -from pyrit.backend.models.targets import TargetCapabilitiesInfo, TargetInstance +from pyrit.models.catalog.target import TargetCapabilitiesInfo, TargetInstance from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.round_robin_target import RoundRobinTarget diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 3ab8571505..a5f4f69839 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -48,20 +48,15 @@ PreviewStep, ) from pyrit.backend.models.initializers import ( - InitializerParameterSummary, ListRegisteredInitializersResponse, - RegisteredInitializer, RegisterInitializerRequest, ) from pyrit.backend.models.scenarios import ( ListRegisteredScenariosResponse, - RegisteredScenario, - ScenarioParameterSummary, + ScenarioRunListResponse, ) from pyrit.backend.models.targets import ( CreateTargetRequest, - TargetCapabilitiesInfo, - TargetInstance, TargetListResponse, ) @@ -105,16 +100,11 @@ "PreviewStep", # Scenarios "ListRegisteredScenariosResponse", - "RegisteredScenario", - "ScenarioParameterSummary", + "ScenarioRunListResponse", # Initializers - "InitializerParameterSummary", "ListRegisteredInitializersResponse", - "RegisteredInitializer", "RegisterInitializerRequest", # Targets "CreateTargetRequest", - "TargetCapabilitiesInfo", - "TargetInstance", "TargetListResponse", ] diff --git a/pyrit/backend/models/initializers.py b/pyrit/backend/models/initializers.py index 49991fd0c4..5d80867ee5 100644 --- a/pyrit/backend/models/initializers.py +++ b/pyrit/backend/models/initializers.py @@ -2,38 +2,23 @@ # Licensed under the MIT license. """ -Initializer API response models. +REST envelopes for the initializer endpoints. -Initializers configure the PyRIT environment (targets, datasets, env vars) -before scenario execution. These models represent initializer metadata. +Canonical initializer catalog types (``RegisteredInitializer``, +``InitializerParameterSummary``) live in ``pyrit.models.catalog.initializer`` +and should be imported from there directly. """ from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo from pyrit.models import REGISTRY_NAME_PATTERN +from pyrit.models.catalog.initializer import RegisteredInitializer - -class InitializerParameterSummary(BaseModel): - """Summary of an initializer-declared parameter.""" - - name: str = Field(..., description="Parameter name") - description: str = Field(..., description="Human-readable description of the parameter") - default: list[str] | None = Field(None, description="Default value(s), or None if required") - - -class RegisteredInitializer(BaseModel): - """Summary of a registered initializer.""" - - initializer_name: str = Field(..., description="Initializer registry name (e.g., 'target')") - initializer_type: str = Field(..., description="Initializer class name (e.g., 'TargetInitializer')") - description: str = Field("", description="Human-readable description of the initializer") - required_env_vars: list[str] = Field( - default_factory=list, description="Environment variables required by this initializer" - ) - supported_parameters: list[InitializerParameterSummary] = Field( - default_factory=list, description="Parameters accepted by this initializer" - ) +__all__ = [ + "ListRegisteredInitializersResponse", + "RegisterInitializerRequest", +] class ListRegisteredInitializersResponse(BaseModel): diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index aaac688cf0..1fa0690cbb 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -2,48 +2,23 @@ # Licensed under the MIT license. """ -Scenario API response models. +REST envelopes for the scenario endpoints. -Scenarios are multi-attack security testing campaigns. These models represent -the metadata about available scenarios (listing) and scenario execution (runs). +Canonical scenario catalog/run types (``RegisteredScenario``, +``ScenarioParameterSummary``, ``ScenarioRunSummary``, ``RunScenarioRequest``) +live in ``pyrit.models.catalog.scenario`` and should be imported from there +directly. """ -from datetime import datetime -from enum import Enum -from typing import Any - from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo +from pyrit.models.catalog.scenario import RegisteredScenario, ScenarioRunSummary - -class ScenarioParameterSummary(BaseModel): - """Summary of a scenario-declared parameter.""" - - name: str = Field(..., description="Parameter name (e.g., 'max_turns')") - description: str = Field(..., description="Human-readable description of the parameter") - default: str | None = Field(None, description="Default value as a display string, or None if required") - param_type: str = Field(..., description="Type of the parameter as a display string (e.g., 'int', 'str')") - choices: list[str] | None = Field(None, description="Allowed values as strings, or None if unconstrained") - is_list: bool = Field(False, description="True when the parameter accepts a list of values (e.g., list[str])") - - -class RegisteredScenario(BaseModel): - """Summary of a registered scenario.""" - - scenario_name: str = Field(..., description="Scenario name (e.g., 'foundry.red_team_agent')") - scenario_type: str = Field(..., description="Scenario type identifier (e.g., 'RedTeamAgentScenario')") - description: str = Field(..., description="Human-readable description of the scenario") - default_strategy: str = Field(..., description="Default strategy name used when none specified") - aggregate_strategies: list[str] = Field( - ..., description="Aggregate strategies that combine multiple attack approaches" - ) - all_strategies: list[str] = Field(..., description="All available concrete strategy names") - default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") - max_dataset_size: int | None = Field(None, description="Maximum items per dataset (None means unlimited)") - supported_parameters: list[ScenarioParameterSummary] = Field( - default_factory=list, description="Scenario-declared custom parameters" - ) +__all__ = [ + "ListRegisteredScenariosResponse", + "ScenarioRunListResponse", +] class ListRegisteredScenariosResponse(BaseModel): @@ -53,73 +28,6 @@ class ListRegisteredScenariosResponse(BaseModel): pagination: PaginationInfo = Field(..., description="Pagination metadata") -# ============================================================================ -# Scenario Run Models -# ============================================================================ - - -class ScenarioRunStatus(str, Enum): - """Status of a scenario run, aligned with core ScenarioRunState.""" - - CREATED = "CREATED" - INITIALIZING = "INITIALIZING" - IN_PROGRESS = "IN_PROGRESS" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - CANCELLED = "CANCELLED" - - -class RunScenarioRequest(BaseModel): - """Request body for starting a scenario run.""" - - scenario_name: str = Field(..., description="Scenario name (e.g., 'foundry.red_team_agent')") - target_name: str = Field(..., description="Name of a registered target from the TargetRegistry") - initializers: list[str] | None = Field( - None, description="Initializer names to run before scenario (e.g., ['target', 'load_default_datasets'])" - ) - strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") - dataset_names: list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)") - max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") - max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") - max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") - labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") - scenario_params: dict[str, Any] | None = Field( - None, - description="Custom parameters for the scenario (passed to scenario.set_params_from_args). " - "Keys are parameter names declared by the scenario's supported_parameters().", - ) - initializer_args: dict[str, dict[str, Any]] | None = Field( - None, - description="Per-initializer arguments keyed by initializer name. " - "Each value is a dict of args passed to that initializer's set_params_from_args(). " - "Example: {'target': {'endpoint': 'https://...'}}.", - ) - scenario_result_id: str | None = Field( - None, - description="Optional ID of an existing ScenarioResult to resume. " - "If provided, the scenario will resume from prior progress instead of starting fresh.", - ) - - -class ScenarioRunSummary(BaseModel): - """Response for a scenario run (status + result details).""" - - scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") - scenario_name: str = Field(..., description="Registry key of the scenario being run") - scenario_version: int = Field(0, ge=0, description="Version of the scenario") - status: ScenarioRunStatus = Field(..., description="Current run status") - created_at: datetime = Field(..., description="When the run was created") - updated_at: datetime = Field(..., description="When the run status last changed") - error: str | None = Field(None, description="Error message if status is FAILED") - error_type: str | None = Field(None, description="Exception class name if status is FAILED") - strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") - total_attacks: int = Field(0, ge=0, description="Total number of attack results persisted for this run") - completed_attacks: int = Field(0, ge=0, description="Number of attacks that reached a terminal outcome") - objective_achieved_rate: int = Field(0, ge=0, le=100, description="Success rate as percentage (0-100)") - labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") - completed_at: datetime | None = Field(None, description="When the scenario finished") - - class ScenarioRunListResponse(BaseModel): """Response for listing scenario runs.""" diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index 9e5ce0c56d..e31d1c362a 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -2,13 +2,11 @@ # Licensed under the MIT license. """ -Target instance models. +REST envelopes and write-request types for the target endpoints. -Targets have two concepts: -- Types: Static metadata bundled with frontend (from registry) -- Instances: Runtime objects created via API with specific configuration - -This module defines the Instance models for runtime target management. +Canonical target catalog types (``TargetInstance``, ``TargetCapabilitiesInfo``) +live in ``pyrit.models.catalog.target`` and should be imported from there +directly. """ from typing import Any, Literal @@ -16,58 +14,12 @@ from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo +from pyrit.models.catalog.target import TargetInstance - -class TargetCapabilitiesInfo(BaseModel): - """ - Wire-format snapshot of a target's capabilities. - - Mirrors the domain ``TargetCapabilities`` dataclass for API consumers - (notably the GUI). Modality combinations (``frozenset[frozenset[...]]``) - are flattened into sorted unique modality lists since the frontend uses - them only for per-piece modality checks. - """ - - supports_multi_turn: bool = Field(False, description="Target natively supports multi-turn conversations") - supports_multi_message_pieces: bool = Field( - False, description="Target supports multiple message pieces in a single request" - ) - supports_json_schema: bool = Field(False, description="Target can constrain output to a provided JSON schema") - supports_json_output: bool = Field(False, description="Target supports JSON output mode") - supports_editable_history: bool = Field(False, description="Target allows attack history to be modified") - supports_system_prompt: bool = Field(False, description="Target natively supports system prompts") - supported_input_modalities: list[str] = Field( - default_factory=lambda: ["text"], - description="Sorted unique input modality data types the target accepts (e.g., ['image_path', 'text'])", - ) - supported_output_modalities: list[str] = Field( - default_factory=lambda: ["text"], - description="Sorted unique output modality data types the target produces (e.g., ['audio_path', 'text'])", - ) - - -class TargetInstance(BaseModel): - """ - A runtime target instance. - - Created either by an initializer (at startup) or by user (via API). - Also used as the create-target response (same shape as GET). - """ - - target_registry_name: str = Field(..., description="Target registry key (e.g., 'azure_openai_chat')") - target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: str | None = Field(None, description="Target endpoint URL") - model_name: str | None = Field(None, description="Model or deployment name used in API calls") - underlying_model_name: str | None = Field(None, description="Underlying model name if different (e.g., 'gpt-4o')") - temperature: float | None = Field(None, description="Temperature parameter for generation") - top_p: float | None = Field(None, description="Top-p parameter for generation") - max_requests_per_minute: int | None = Field(None, description="Maximum requests per minute") - capabilities: TargetCapabilitiesInfo = Field(..., description="Structured snapshot of target capabilities") - target_specific_params: dict[str, Any] | None = Field(None, description="Additional target-specific parameters") - inner_targets: list["TargetInstance"] | None = Field( - None, description="Inner targets for composite targets like RoundRobinTarget" - ) - identifier_hash: str | None = Field(None, description="ComponentIdentifier content hash for duplicate detection") +__all__ = [ + "CreateTargetRequest", + "TargetListResponse", +] class TargetListResponse(BaseModel): diff --git a/pyrit/backend/routes/initializers.py b/pyrit/backend/routes/initializers.py index dae0db900e..9d2e4cbb1d 100644 --- a/pyrit/backend/routes/initializers.py +++ b/pyrit/backend/routes/initializers.py @@ -18,10 +18,10 @@ from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.initializers import ( ListRegisteredInitializersResponse, - RegisteredInitializer, RegisterInitializerRequest, ) from pyrit.backend.services.initializer_service import get_initializer_service +from pyrit.models.catalog.initializer import RegisteredInitializer router = APIRouter(prefix="/initializers", tags=["initializers"]) diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 941d8021fb..cbf08a019e 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -12,18 +12,22 @@ /api/scenarios/runs — scenario execution lifecycle """ +from typing import Any + from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.scenarios import ( ListRegisteredScenariosResponse, - RegisteredScenario, - RunScenarioRequest, ScenarioRunListResponse, - ScenarioRunSummary, ) from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service +from pyrit.models.catalog.scenario import ( + RegisteredScenario, + RunScenarioRequest, + ScenarioRunSummary, +) router = APIRouter(prefix="/scenarios", tags=["scenarios"]) @@ -199,7 +203,7 @@ async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: # 409: {"model": ProblemDetail, "description": "Run not yet completed"}, }, ) -async def get_scenario_run_results(scenario_result_id: str) -> dict: # pyrit-async-suffix-exempt +async def get_scenario_run_results(scenario_result_id: str) -> dict[str, Any]: # pyrit-async-suffix-exempt """ Get detailed results for a completed scenario run. diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index bea53ddef2..c55a6d9ff2 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -13,10 +13,10 @@ from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.targets import ( CreateTargetRequest, - TargetInstance, TargetListResponse, ) from pyrit.backend.services.target_service import get_target_service +from pyrit.models.catalog.target import TargetInstance router = APIRouter(prefix="/targets", tags=["targets"]) diff --git a/pyrit/backend/services/initializer_service.py b/pyrit/backend/services/initializer_service.py index 153ca59412..88182db6a1 100644 --- a/pyrit/backend/services/initializer_service.py +++ b/pyrit/backend/services/initializer_service.py @@ -13,8 +13,10 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.initializers import ( - InitializerParameterSummary, ListRegisteredInitializersResponse, +) +from pyrit.models.catalog.initializer import ( + InitializerParameterSummary, RegisteredInitializer, ) from pyrit.registry import InitializerMetadata, InitializerRegistry diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index c8236aceda..42fe006d8a 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -14,14 +14,13 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from pyrit.backend.models.scenarios import ( +from pyrit.backend.models.scenarios import ScenarioRunListResponse +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, ScenarioResult, ScenarioRunState +from pyrit.models.catalog.scenario import ( RunScenarioRequest, - ScenarioRunListResponse, - ScenarioRunStatus, ScenarioRunSummary, ) -from pyrit.memory import CentralMemory -from pyrit.models import AttackOutcome, ScenarioResult from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import Scenario from pyrit.scenario.core import DatasetConfiguration @@ -166,9 +165,9 @@ async def cancel_run_async(self, *, scenario_result_id: str) -> ScenarioRunSumma return None scenario_result = results[0] - db_status = ScenarioRunStatus(scenario_result.scenario_run_state) + db_status = ScenarioRunState(scenario_result.scenario_run_state) - if db_status in (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED): + if db_status in (ScenarioRunState.COMPLETED, ScenarioRunState.FAILED, ScenarioRunState.CANCELLED): raise ValueError(f"Cannot cancel run in '{db_status}' state.") # Cancel the asyncio task if active and wait for it to finish @@ -475,7 +474,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari if not error and active is not None: error = active.error - status = ScenarioRunStatus(scenario_result.scenario_run_state) + status = ScenarioRunState(scenario_result.scenario_run_state) # Build result fields from DB (always computed so in-progress runs show progress) total_attacks = sum(len(results) for results in scenario_result.attack_results.values()) @@ -488,7 +487,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari scenario_version=scenario_result.scenario_identifier.version, status=status, created_at=scenario_result.creation_time, - updated_at=scenario_result.completion_time, + updated_at=scenario_result.completion_time or scenario_result.creation_time, error=error, error_type=error_type, strategies_used=strategies_used, @@ -519,7 +518,7 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioResult | None: scenario_result = results[0] run_response = self._build_response_from_db(scenario_result=scenario_result) - if run_response.status != ScenarioRunStatus.COMPLETED: + if run_response.status != ScenarioRunState.COMPLETED: raise ValueError(f"Results are only available for completed runs. Current status: '{run_response.status}'.") return scenario_result diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index 939b863306..a859fdc764 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -11,8 +11,8 @@ from functools import lru_cache from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ( - ListRegisteredScenariosResponse, +from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse +from pyrit.models.catalog.scenario import ( RegisteredScenario, ScenarioParameterSummary, ) diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 6663dfa57b..05f4279a25 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -24,9 +24,9 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.targets import ( CreateTargetRequest, - TargetInstance, TargetListResponse, ) +from pyrit.models.catalog.target import TargetInstance from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.openai.openai_target import OpenAITarget diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index eddad0f0f3..403caae898 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from pyrit.models.catalog import ScenarioParameterSummary from pyrit.setup.configuration_loader import ScenarioConfig # --------------------------------------------------------------------------- @@ -643,7 +644,7 @@ def extract_scenario_args(*, parsed: dict[str, Any]) -> dict[str, Any]: # --------------------------------------------------------------------------- -def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Parameter] | None: +def build_parameters_from_api(*, api_params: list[ScenarioParameterSummary]) -> list[Parameter] | None: """ Build ``Parameter`` objects from a scenario catalog's ``supported_parameters``. @@ -652,7 +653,7 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Param can apply per-element coercion and treat list params as ``multi_value``. Args: - api_params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. + api_params: Scenario-declared parameters from ``GET /api/scenarios/catalog/{name}``. Returns: list[Parameter] | None: Parameter list when ``api_params`` is non-empty, else ``None``. @@ -662,20 +663,19 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Param type_map: dict[str, Any] = {"int": int, "float": float, "bool": bool, "str": str} parameters: list[Parameter] = [] for p in api_params: - type_display = p.get("param_type", "") - if p.get("is_list"): + type_display = p.param_type + if p.is_list: element_type = type_map.get(type_display.removeprefix("list[").rstrip("]"), str) resolved_type: Any = list[element_type] # type: ignore[valid-type] else: resolved_type = type_map.get(type_display) - raw_choices = p.get("choices") - choices: tuple[Any, ...] | None = tuple(raw_choices) if raw_choices else None + choices: tuple[Any, ...] | None = tuple(p.choices) if p.choices else None parameters.append( Parameter( - name=p["name"], - description=p.get("description", ""), + name=p.name, + description=p.description, param_type=resolved_type, - default=p.get("default"), + default=p.default, choices=choices, ) ) diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py index 3580c8e5b6..5640edd275 100644 --- a/pyrit/cli/_output.py +++ b/pyrit/cli/_output.py @@ -4,14 +4,25 @@ """ Console output formatting for the PyRIT CLI thin client. -All functions accept plain ``dict`` payloads (deserialized JSON from the REST -API) and print human-readable output to stdout. No heavy pyrit imports. +All public ``print_*`` functions accept typed ``pyrit.models`` objects +(``RegisteredScenario``, ``RegisteredInitializer``, ``TargetInstance``, +``ScenarioRunSummary``, ``ScenarioResult``). The heavy ``pyrit.models`` +import is deferred to each function so importing this module stays cheap. """ from __future__ import annotations import sys -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.models import ScenarioResult + from pyrit.models.catalog import ( + RegisteredInitializer, + RegisteredScenario, + ScenarioRunSummary, + TargetInstance, + ) try: import termcolor @@ -67,12 +78,12 @@ def _wrap(*, text: str, indent: str, width: int = 78) -> str: # --------------------------------------------------------------------------- -def print_scenario_list(*, items: list[dict[str, Any]]) -> None: +def print_scenario_list(*, items: list[RegisteredScenario]) -> None: """ Print a formatted list of scenarios. Args: - items: List of scenario dicts from ``GET /api/scenarios/catalog``. + items: Scenarios from ``GET /api/scenarios/catalog``. """ if not items: print("No scenarios found.") @@ -81,39 +92,30 @@ def print_scenario_list(*, items: list[dict[str, Any]]) -> None: print("\nAvailable Scenarios:") print("=" * 80) for sc in items: - _header(sc.get("scenario_name", "unknown")) - print(f" Class: {sc.get('scenario_type', '')}") - desc = sc.get("description", "") - if desc: + _header(sc.scenario_name) + print(f" Class: {sc.scenario_type}") + if sc.description: print(" Description:") - print(_wrap(text=desc, indent=" ")) - agg = sc.get("aggregate_strategies") or [] - if agg: + print(_wrap(text=sc.description, indent=" ")) + if sc.aggregate_strategies: print(" Aggregate Strategies:") - print(_wrap(text=", ".join(agg), indent=" - ")) - strategies = sc.get("all_strategies") or [] - if strategies: - print(f" Available Strategies ({len(strategies)}):") - print(_wrap(text=", ".join(strategies), indent=" ")) - default_strat = sc.get("default_strategy") - if default_strat: - print(f" Default Strategy: {default_strat}") - datasets = sc.get("default_datasets") or [] - max_ds = sc.get("max_dataset_size") - if datasets: - suffix = f", max {max_ds} per dataset" if max_ds else "" - print(f" Default Datasets ({len(datasets)}{suffix}):") - print(_wrap(text=", ".join(datasets), indent=" ")) - params = sc.get("supported_parameters") or [] - if params: + print(_wrap(text=", ".join(sc.aggregate_strategies), indent=" - ")) + if sc.all_strategies: + print(f" Available Strategies ({len(sc.all_strategies)}):") + print(_wrap(text=", ".join(sc.all_strategies), indent=" ")) + if sc.default_strategy: + print(f" Default Strategy: {sc.default_strategy}") + if sc.default_datasets: + suffix = f", max {sc.max_dataset_size} per dataset" if sc.max_dataset_size else "" + print(f" Default Datasets ({len(sc.default_datasets)}{suffix}):") + print(_wrap(text=", ".join(sc.default_datasets), indent=" ")) + if sc.supported_parameters: print(" Supported Parameters:") - for p in params: - default_str = f" [default: {p.get('default')!r}]" if p.get("default") is not None else "" - type_str = f" ({p.get('param_type', '')})" if p.get("param_type") else "" - choices = p.get("choices") - choices_display = ", ".join(choices) if isinstance(choices, list) else choices - choices_str = f" [choices: {choices_display}]" if choices_display else "" - print(f" - {p.get('name', '?')}{type_str}{default_str}{choices_str}: {p.get('description', '')}") + for p in sc.supported_parameters: + default_str = f" [default: {p.default!r}]" if p.default is not None else "" + type_str = f" ({p.param_type})" if p.param_type else "" + choices_str = f" [choices: {', '.join(p.choices)}]" if p.choices else "" + print(f" - {p.name}{type_str}{default_str}{choices_str}: {p.description}") print("\n" + "=" * 80) print(f"\nTotal scenarios: {len(items)}") @@ -123,12 +125,12 @@ def print_scenario_list(*, items: list[dict[str, Any]]) -> None: # --------------------------------------------------------------------------- -def print_initializer_list(*, items: list[dict[str, Any]]) -> None: +def print_initializer_list(*, items: list[RegisteredInitializer]) -> None: """ Print a formatted list of initializers. Args: - items: List of initializer dicts from ``GET /api/initializers``. + items: Initializers from ``GET /api/initializers``. """ if not items: print("No initializers found.") @@ -137,25 +139,22 @@ def print_initializer_list(*, items: list[dict[str, Any]]) -> None: print("\nAvailable Initializers:") print("=" * 80) for init in items: - _header(init.get("initializer_name", "unknown")) - print(f" Class: {init.get('initializer_type', '')}") - env_vars = init.get("required_env_vars") or [] - if env_vars: + _header(init.initializer_name) + print(f" Class: {init.initializer_type}") + if init.required_env_vars: print(" Required Environment Variables:") - for var in env_vars: + for var in init.required_env_vars: print(f" - {var}") else: print(" Required Environment Variables: None") - params = init.get("supported_parameters") or [] - if params: + if init.supported_parameters: print(" Supported Parameters:") - for p in params: - default_str = f" [default: {p.get('default')}]" if p.get("default") else "" - print(f" - {p.get('name', '?')}{default_str}: {p.get('description', '')}") - desc = init.get("description", "") - if desc: + for p in init.supported_parameters: + default_str = f" [default: {p.default}]" if p.default else "" + print(f" - {p.name}{default_str}: {p.description}") + if init.description: print(" Description:") - print(_wrap(text=desc, indent=" ")) + print(_wrap(text=init.description, indent=" ")) print("\n" + "=" * 80) print(f"\nTotal initializers: {len(items)}") @@ -165,12 +164,12 @@ def print_initializer_list(*, items: list[dict[str, Any]]) -> None: # --------------------------------------------------------------------------- -def print_target_list(*, items: list[dict[str, Any]]) -> None: +def print_target_list(*, items: list[TargetInstance]) -> None: """ Print a formatted list of targets. Args: - items: List of target dicts from ``GET /api/targets``. + items: Targets from ``GET /api/targets``. """ if not items: print("\nNo targets found in registry.") @@ -183,14 +182,13 @@ def print_target_list(*, items: list[dict[str, Any]]) -> None: print("\nRegistered Targets:") print("=" * 80) for tgt in items: - _header(tgt.get("target_registry_name", "unknown")) - print(f" Class: {tgt.get('target_type', '')}") - model = tgt.get("underlying_model_name") or tgt.get("model_name") or "" + _header(tgt.target_registry_name) + print(f" Class: {tgt.target_type}") + model = tgt.underlying_model_name or tgt.model_name or "" if model: print(f" Model: {model}") - endpoint = tgt.get("endpoint") or "" - if endpoint: - print(f" Endpoint: {endpoint}") + if tgt.endpoint: + print(f" Endpoint: {tgt.endpoint}") print("\n" + "=" * 80) print(f"\nTotal targets: {len(items)}") @@ -200,19 +198,15 @@ def print_target_list(*, items: list[dict[str, Any]]) -> None: # --------------------------------------------------------------------------- -def print_scenario_run_progress(*, run: dict[str, Any], total_strategies: int = 0) -> None: +def print_scenario_run_progress(*, run: ScenarioRunSummary, total_strategies: int = 0) -> None: """ Print a single-line progress update (overwrites the current line). Args: - run: ScenarioRunSummary dict from ``GET /api/scenarios/runs/{id}``. + run: ``ScenarioRunSummary`` from ``GET /api/scenarios/runs/{id}``. total_strategies: Total number of strategies expected (0 if unknown). """ - run_status = run.get("status", "UNKNOWN") - total = run.get("total_attacks", 0) - completed = run.get("completed_attacks", 0) - rate = run.get("objective_achieved_rate", 0) - strategies_done = len(run.get("strategies_used") or []) + strategies_done = len(run.strategies_used) # Strategies the user passed may be aggregates that expand on the server # (e.g. `single_turn` -> N concrete strategies). Trust whichever count is larger. effective_total = max(total_strategies, strategies_done) @@ -224,52 +218,43 @@ def print_scenario_run_progress(*, run: dict[str, Any], total_strategies: int = elif strategies_done > 0: parts.append(f"strategies: {strategies_done}") - if total > 0: - pct = int((completed / total) * 100) + if run.total_attacks > 0: + pct = int((run.completed_attacks / run.total_attacks) * 100) bar_width = 30 - filled = int(bar_width * completed / total) + filled = int(bar_width * run.completed_attacks / run.total_attacks) bar = "█" * filled + "░" * (bar_width - filled) - parts.append(f"[{bar}] {completed}/{total} attacks ({pct}%)") + parts.append(f"[{bar}] {run.completed_attacks}/{run.total_attacks} attacks ({pct}%)") else: - parts.append(f"attacks: {completed}") + parts.append(f"attacks: {run.completed_attacks}") - parts.append(f"success rate: {rate}%") - parts.append(run_status) + parts.append(f"success rate: {run.objective_achieved_rate}%") + parts.append(run.status.value) line = "\r " + " | ".join(parts) sys.stdout.write(line) sys.stdout.flush() -def print_scenario_run_summary(*, run: dict[str, Any]) -> None: +def print_scenario_run_summary(*, run: ScenarioRunSummary) -> None: """ Print a brief summary of a completed scenario run. Args: - run: ScenarioRunSummary dict. + run: ``ScenarioRunSummary``. """ print() # newline after progress bar - status = run.get("status", "UNKNOWN") - name = run.get("scenario_name", "unknown") - rid = run.get("scenario_result_id", "?") - total = run.get("total_attacks", 0) - completed = run.get("completed_attacks", 0) - rate = run.get("objective_achieved_rate", 0) - - print(f"\nScenario: {name}") - print(f" Result ID: {rid}") - print(f" Status: {status}") - print(f" Total Attacks: {total}") - print(f" Completed: {completed}") - print(f" Success Rate: {rate}%") + print(f"\nScenario: {run.scenario_name}") + print(f" Result ID: {run.scenario_result_id}") + print(f" Status: {run.status.value}") + print(f" Total Attacks: {run.total_attacks}") + print(f" Completed: {run.completed_attacks}") + print(f" Success Rate: {run.objective_achieved_rate}%") - error = run.get("error") - if error: - print(f" Error: {error}") + if run.error: + print(f" Error: {run.error}") - strategies = run.get("strategies_used") or [] - if strategies: - print(f" Strategies: {', '.join(strategies)}") + if run.strategies_used: + print(f" Strategies: {', '.join(run.strategies_used)}") # --------------------------------------------------------------------------- @@ -277,19 +262,17 @@ def print_scenario_run_summary(*, run: dict[str, Any]) -> None: # --------------------------------------------------------------------------- -async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: +async def print_scenario_result_async(*, result: ScenarioResult) -> None: """ Print detailed scenario results using the output module. Args: - result_dict: ``ScenarioResult.to_dict()`` payload from the REST API. + result: Deserialized ``ScenarioResult`` from the REST API. """ - from pyrit.models.scenario_result import ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - scenario_result = ScenarioResult.from_dict(result_dict) printer = PrettyScenarioResultMemoryPrinter() - await printer.write_async(scenario_result) + await printer.write_async(result) # --------------------------------------------------------------------------- @@ -297,12 +280,12 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: # --------------------------------------------------------------------------- -def print_scenario_runs_list(*, runs: list[dict[str, Any]]) -> None: +def print_scenario_runs_list(*, runs: list[ScenarioRunSummary]) -> None: """ Print a list of scenario run summaries. Args: - runs: List of ScenarioRunSummary dicts from ``GET /api/scenarios/runs``. + runs: Scenario runs from ``GET /api/scenarios/runs``. """ if not runs: print("No scenario runs found.") @@ -311,13 +294,12 @@ def print_scenario_runs_list(*, runs: list[dict[str, Any]]) -> None: print("\nScenario Run History:") print("=" * 80) for idx, run in enumerate(runs, start=1): - status = run.get("status", "?") - name = run.get("scenario_name", "unknown") - rid = run.get("scenario_result_id", "?")[:8] - total = run.get("total_attacks", 0) - rate = run.get("objective_achieved_rate", 0) - created = run.get("created_at", "?") - print(f" {idx}) [{status}] {name} (id: {rid}…) — {total} attacks, {rate}% success — {created}") + rid = run.scenario_result_id[:8] + created = run.created_at.isoformat() if run.created_at else "?" + print( + f" {idx}) [{run.status.value}] {run.scenario_name} (id: {rid}…) — " + f"{run.total_attacks} attacks, {run.objective_achieved_rate}% success — {created}" + ) print("=" * 80) print(f"\nTotal runs: {len(runs)}") diff --git a/pyrit/cli/api_client.py b/pyrit/cli/api_client.py index bfd75ca420..98ce767f66 100644 --- a/pyrit/cli/api_client.py +++ b/pyrit/cli/api_client.py @@ -4,15 +4,26 @@ """ Async REST client for the PyRIT backend API. -Uses ``httpx`` internally but defers the import to method calls so that -importing this module does not trigger the import-guard ban on ``httpx`` -at CLI parse time. +Returns typed ``pyrit.models`` objects (canonical wire-data types defined in +``pyrit.models.catalog`` plus ``ScenarioResult``). Heavy imports — ``httpx`` +and ``pyrit.models`` — are deferred to method bodies so that importing this +module does not trigger the CLI parse-time import-guard ban on either. """ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pyrit.models import ScenarioResult + from pyrit.models.catalog import ( + RegisteredInitializer, + RegisteredScenario, + RunScenarioRequest, + ScenarioRunSummary, + TargetInstance, + ) _logger = logging.getLogger(__name__) @@ -92,48 +103,57 @@ async def health_check_async(self) -> bool: # Scenarios # ------------------------------------------------------------------ - async def list_scenarios_async(self, *, limit: int = 200) -> dict[str, Any]: + async def list_scenarios_async(self, *, limit: int = 200) -> list[RegisteredScenario]: """ List all available scenarios. Returns: - dict: ``ListRegisteredScenariosResponse`` payload. + list[RegisteredScenario]: All scenarios in the catalog. """ - return await self._get_json_async(path="/api/scenarios/catalog", params={"limit": limit}) + from pyrit.models.catalog import RegisteredScenario - async def get_scenario_async(self, *, scenario_name: str) -> dict[str, Any] | None: + payload = await self._get_json_async(path="/api/scenarios/catalog", params={"limit": limit}) + return [RegisteredScenario.model_validate(item) for item in payload.get("items", [])] + + async def get_scenario_async(self, *, scenario_name: str) -> RegisteredScenario | None: """ Get metadata for a single scenario. Returns: - dict | None: ``RegisteredScenario`` payload, or ``None`` if 404. + RegisteredScenario | None: The scenario, or ``None`` if 404. Raises: httpx.HTTPStatusError: For non-404 HTTP error responses. """ import httpx + from pyrit.models.catalog import RegisteredScenario + try: - return await self._get_json_async(path=f"/api/scenarios/catalog/{scenario_name}") + payload = await self._get_json_async(path=f"/api/scenarios/catalog/{scenario_name}") except httpx.HTTPStatusError as exc: if exc.response.status_code == 404: return None raise + return RegisteredScenario.model_validate(payload) # ------------------------------------------------------------------ # Initializers # ------------------------------------------------------------------ - async def list_initializers_async(self, *, limit: int = 200) -> dict[str, Any]: + async def list_initializers_async(self, *, limit: int = 200) -> list[RegisteredInitializer]: """ List all available initializers. Returns: - dict: ``ListRegisteredInitializersResponse`` payload. + list[RegisteredInitializer]: All initializers in the catalog. """ - return await self._get_json_async(path="/api/initializers", params={"limit": limit}) + from pyrit.models.catalog import RegisteredInitializer - async def register_initializer_async(self, *, name: str, script_content: str) -> dict[str, Any]: + payload = await self._get_json_async(path="/api/initializers", params={"limit": limit}) + return [RegisteredInitializer.model_validate(item) for item in payload.get("items", [])] + + async def register_initializer_async(self, *, name: str, script_content: str) -> RegisteredInitializer: """ Register a custom initializer by uploading Python source code. @@ -142,11 +162,13 @@ async def register_initializer_async(self, *, name: str, script_content: str) -> script_content: Python source code containing a ``PyRITInitializer`` subclass. Returns: - dict: ``RegisteredInitializer`` payload. + RegisteredInitializer: The newly registered initializer. Raises: ServerNotAvailableError: If custom initializers are disabled (403). """ + from pyrit.models.catalog import RegisteredInitializer + client = self._get_client() resp = await client.post( "/api/initializers", @@ -156,41 +178,49 @@ async def register_initializer_async(self, *, name: str, script_content: str) -> detail = resp.json().get("detail", "Custom initializer operations are disabled on the server.") raise ServerNotAvailableError(detail) self._raise_for_status(resp) - return resp.json() + return RegisteredInitializer.model_validate(resp.json()) # ------------------------------------------------------------------ # Targets # ------------------------------------------------------------------ - async def list_targets_async(self, *, limit: int = 200) -> dict[str, Any]: + async def list_targets_async(self, *, limit: int = 200) -> list[TargetInstance]: """ List all available targets. Returns: - dict: ``TargetListResponse`` payload. + list[TargetInstance]: All targets registered on the server. """ - return await self._get_json_async(path="/api/targets", params={"limit": limit}) + from pyrit.models.catalog import TargetInstance + + payload = await self._get_json_async(path="/api/targets", params={"limit": limit}) + return [TargetInstance.model_validate(item) for item in payload.get("items", [])] # ------------------------------------------------------------------ # Scenario runs # ------------------------------------------------------------------ - async def start_scenario_run_async(self, *, request: dict[str, Any]) -> dict[str, Any]: + async def start_scenario_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSummary: """ Start a new scenario run. Args: - request: ``RunScenarioRequest``-shaped dict. + request: Typed run request describing the scenario, initializers, and overrides. Returns: - dict: ``ScenarioRunSummary`` payload. + ScenarioRunSummary: The newly-created scenario run. """ + from pyrit.models.catalog import ScenarioRunSummary + client = self._get_client() - resp = await client.post("/api/scenarios/runs", json=request) + resp = await client.post( + "/api/scenarios/runs", + json=request.model_dump(mode="json", exclude_none=True), + ) self._raise_for_status(resp) - return resp.json() + return ScenarioRunSummary.model_validate(resp.json()) - async def get_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, Any]: + async def get_scenario_run_async(self, *, scenario_result_id: str) -> ScenarioRunSummary: """ Get the current status of a scenario run. @@ -200,13 +230,15 @@ async def get_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, default read timeout. The other endpoints keep the configured timeout. Returns: - dict: ``ScenarioRunSummary`` payload. + ScenarioRunSummary: The current state of the scenario run. Raises: ServerNotAvailableError: If the server cannot be reached. """ import httpx + from pyrit.models.catalog import ScenarioRunSummary + client = self._get_client() try: resp = await client.get( @@ -221,37 +253,45 @@ async def get_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, "or pass '--server-url '." ) from exc self._raise_for_status(resp) - return resp.json() + return ScenarioRunSummary.model_validate(resp.json()) - async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> dict[str, Any]: + async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> ScenarioResult: """ Get detailed results for a completed scenario run. Returns: - dict: ``ScenarioResult.to_dict()`` payload. + ScenarioResult: The full scenario result deserialized from the server payload. """ - return await self._get_json_async(path=f"/api/scenarios/runs/{scenario_result_id}/results") + from pyrit.models import ScenarioResult - async def cancel_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, Any]: + payload = await self._get_json_async(path=f"/api/scenarios/runs/{scenario_result_id}/results") + return ScenarioResult.model_validate(payload) + + async def cancel_scenario_run_async(self, *, scenario_result_id: str) -> ScenarioRunSummary: """ Cancel a running scenario. Returns: - dict: Updated ``ScenarioRunSummary`` payload. + ScenarioRunSummary: Updated summary reflecting the cancellation request. """ + from pyrit.models.catalog import ScenarioRunSummary + client = self._get_client() resp = await client.post(f"/api/scenarios/runs/{scenario_result_id}/cancel") self._raise_for_status(resp) - return resp.json() + return ScenarioRunSummary.model_validate(resp.json()) - async def list_scenario_runs_async(self, *, limit: int = 100) -> dict[str, Any]: + async def list_scenario_runs_async(self, *, limit: int = 100) -> list[ScenarioRunSummary]: """ List tracked scenario runs. Returns: - dict: ``ScenarioRunListResponse`` payload. + list[ScenarioRunSummary]: All tracked scenario runs. """ - return await self._get_json_async(path="/api/scenarios/runs", params={"limit": limit}) + from pyrit.models.catalog import ScenarioRunSummary + + payload = await self._get_json_async(path="/api/scenarios/runs", params={"limit": limit}) + return [ScenarioRunSummary.model_validate(item) for item in payload.get("items", [])] # ------------------------------------------------------------------ # Lifecycle diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 1e4467f929..50591f6758 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -17,7 +17,7 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from pyrit.cli._cli_args import ( ARG_HELP, @@ -27,6 +27,14 @@ validate_log_level_argparse, ) +if TYPE_CHECKING: + from pyrit.models.catalog import ( + RegisteredScenario, + RunScenarioRequest, + ScenarioParameterSummary, + ScenarioRunSummary, + ) + _TERMINAL_STATUSES = {"COMPLETED", "FAILED", "CANCELLED"} @@ -256,13 +264,13 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: } -def _scenario_param_kwargs(*, param: dict[str, Any]) -> dict[str, Any]: +def _scenario_param_kwargs(*, param: ScenarioParameterSummary) -> dict[str, Any]: """ - Build argparse ``add_argument`` kwargs for a scenario-declared parameter dict. + Build argparse ``add_argument`` kwargs for a scenario-declared parameter. - Uses ``param_type``, ``is_list`` and ``choices`` from the catalog payload - so list params accept ``nargs='+'`` and scalar params get client-side - type coercion and choice validation. + Uses ``param_type``, ``is_list`` and ``choices`` so list params accept + ``nargs='+'`` and scalar params get client-side type coercion and choice + validation. Args: param: Single entry from ``RegisteredScenario.supported_parameters``. @@ -271,16 +279,16 @@ def _scenario_param_kwargs(*, param: dict[str, Any]) -> dict[str, Any]: dict[str, Any]: kwargs ready to pass to ``ArgumentParser.add_argument``. """ kwargs: dict[str, Any] = { - "dest": f"{_SCENARIO_DEST_PREFIX}{param.get('name', '')}", + "dest": f"{_SCENARIO_DEST_PREFIX}{param.name}", "default": argparse.SUPPRESS, - "help": param.get("description", ""), + "help": param.description, } - if param.get("is_list"): + if param.is_list: kwargs["nargs"] = "+" else: - coercer = _SCALAR_TYPE_COERCERS.get(param.get("param_type", "")) + coercer = _SCALAR_TYPE_COERCERS.get(param.param_type) if coercer is not None and coercer is not str: - param_name = param.get("name", "") + param_name = param.name def _typed(raw: str) -> Any: try: @@ -291,24 +299,22 @@ def _typed(raw: str) -> Any: ) from exc kwargs["type"] = _typed - choices = param.get("choices") - if choices: - kwargs["choices"] = list(choices) + if param.choices: + kwargs["choices"] = list(param.choices) return kwargs -def _add_scenario_params_from_api(*, parser: ArgumentParser, params: list[dict[str, Any]]) -> None: +def _add_scenario_params_from_api(*, parser: ArgumentParser, params: list[ScenarioParameterSummary]) -> None: """ - Add scenario-declared parameters (from the API response) as CLI flags. + Add scenario-declared parameters as CLI flags. Args: parser: Parser to extend. - params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. + params: Scenario-declared parameters from ``GET /api/scenarios/catalog/{name}``. """ seen_flags: set[str] = set(parser._option_string_actions.keys()) for p in params: - name = p.get("name", "") - flag = f"--{name.replace('_', '-')}" + flag = f"--{p.name.replace('_', '-')}" if flag in seen_flags: continue parser.add_argument(flag, **_scenario_param_kwargs(param=p)) @@ -470,16 +476,16 @@ async def _handle_list_commands_async(*, client: Any, parsed_args: Namespace) -> from pyrit.cli import _output if parsed_args.list_scenarios: - resp = await client.list_scenarios_async() - _output.print_scenario_list(items=resp.get("items", [])) + scenarios = await client.list_scenarios_async() + _output.print_scenario_list(items=scenarios) return 0 if parsed_args.list_initializers: - resp = await client.list_initializers_async() - _output.print_initializer_list(items=resp.get("items", [])) + initializers = await client.list_initializers_async() + _output.print_initializer_list(items=initializers) return 0 if parsed_args.list_targets: - resp = await client.list_targets_async() - _output.print_target_list(items=resp.get("items", [])) + targets = await client.list_targets_async() + _output.print_target_list(items=targets) return 0 return None @@ -512,7 +518,7 @@ async def _handle_add_initializer_async(*, client: Any, parsed_args: Namespace) def _reparse_with_scenario_params( - *, parsed_args: Namespace, supported_params: list[dict[str, Any]] + *, parsed_args: Namespace, supported_params: list[ScenarioParameterSummary] ) -> Namespace | None: """ Re-parse the original args with scenario-declared flags added to the base parser. @@ -545,16 +551,17 @@ def _reparse_with_scenario_params( return None -def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> dict[str, Any]: +def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> RunScenarioRequest: """ - Build the ``RunScenarioRequest`` dict from parsed CLI args. + Build the ``RunScenarioRequest`` typed object from parsed CLI args. Returns: - dict[str, Any]: The request payload to send to ``POST /api/scenarios/runs``. + RunScenarioRequest: The typed request payload to send to ``POST /api/scenarios/runs``. """ from pyrit.cli._cli_args import parse_memory_labels + from pyrit.models.catalog import RunScenarioRequest - request: dict[str, Any] = { + kwargs: dict[str, Any] = { "scenario_name": scenario_name, "target_name": parsed_args.target or "", } @@ -570,28 +577,28 @@ def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> dict[st init_names.append(name) if entry.get("args"): init_args[name] = entry["args"] - request["initializers"] = init_names + kwargs["initializers"] = init_names if init_args: - request["initializer_args"] = init_args + kwargs["initializer_args"] = init_args if parsed_args.scenario_strategies: - request["strategies"] = parsed_args.scenario_strategies + kwargs["strategies"] = parsed_args.scenario_strategies if parsed_args.max_concurrency is not None: - request["max_concurrency"] = parsed_args.max_concurrency + kwargs["max_concurrency"] = parsed_args.max_concurrency if parsed_args.max_retries is not None: - request["max_retries"] = parsed_args.max_retries + kwargs["max_retries"] = parsed_args.max_retries if parsed_args.dataset_names: - request["dataset_names"] = parsed_args.dataset_names + kwargs["dataset_names"] = parsed_args.dataset_names if parsed_args.max_dataset_size is not None: - request["max_dataset_size"] = parsed_args.max_dataset_size + kwargs["max_dataset_size"] = parsed_args.max_dataset_size if parsed_args.memory_labels: - request["labels"] = parse_memory_labels(json_string=parsed_args.memory_labels) + kwargs["labels"] = parse_memory_labels(json_string=parsed_args.memory_labels) scenario_params = _extract_scenario_args(parsed=parsed_args) if scenario_params: - request["scenario_params"] = scenario_params + kwargs["scenario_params"] = scenario_params - return request + return RunScenarioRequest(**kwargs) async def _poll_until_terminal_async( @@ -599,20 +606,19 @@ async def _poll_until_terminal_async( client: Any, scenario_result_id: str, total_strategies: int, -) -> dict[str, Any]: +) -> ScenarioRunSummary: """ Poll the server until the run reaches a terminal status. Returns: - dict[str, Any]: The final run dict. + ScenarioRunSummary: The final run summary. """ from pyrit.cli import _output while True: run = await client.get_scenario_run_async(scenario_result_id=scenario_result_id) - status = run.get("status", "UNKNOWN") _output.print_scenario_run_progress(run=run, total_strategies=total_strategies) - if status in _TERMINAL_STATUSES: + if run.status in _TERMINAL_STATUSES: return run await asyncio.sleep(0.5) @@ -621,7 +627,7 @@ async def _run_scenario_async( *, client: Any, parsed_args: Namespace, - scenario_meta: dict[str, Any], + scenario_meta: RegisteredScenario, ) -> int: """ Start a scenario run, poll for completion, and print results. @@ -634,7 +640,7 @@ async def _run_scenario_async( scenario_name = parsed_args.scenario_name request = _build_run_request(parsed_args=parsed_args, scenario_name=scenario_name) - total_strategies = len(request.get("strategies") or scenario_meta.get("all_strategies") or []) + total_strategies = len(request.strategies or scenario_meta.all_strategies or []) print(f"\nRunning scenario: {scenario_name}") sys.stdout.flush() @@ -644,7 +650,7 @@ async def _run_scenario_async( print(f"Error starting scenario: {exc}") return 1 - scenario_result_id = run.get("scenario_result_id", "") + scenario_result_id = run.scenario_result_id try: run = await _poll_until_terminal_async( @@ -661,16 +667,16 @@ async def _run_scenario_async( print("Warning: could not cancel scenario run on server.") return 1 - if run.get("status") == "COMPLETED": + if run.status == "COMPLETED": try: detail = await client.get_scenario_run_results_async(scenario_result_id=scenario_result_id) - await _output.print_scenario_result_async(result_dict=detail) + await _output.print_scenario_result_async(result=detail) except Exception: _output.print_scenario_run_summary(run=run) else: _output.print_scenario_run_summary(run=run) - return 0 if run.get("status") == "COMPLETED" else 1 + return 0 if run.status == "COMPLETED" else 1 async def _dispatch_with_client_async(*, client: Any, parsed_args: Namespace) -> int: @@ -695,15 +701,15 @@ async def _dispatch_with_client_async(*, client: Any, parsed_args: Namespace) -> scenario_meta = await client.get_scenario_async(scenario_name=scenario_name) if scenario_meta is None: print(f"Error: Scenario '{scenario_name}' not found on server.") - resp = await client.list_scenarios_async() - names = [s.get("scenario_name", "") for s in resp.get("items", [])] + scenarios = await client.list_scenarios_async() + names = [s.scenario_name for s in scenarios] if names: print(f"Available scenarios: {', '.join(names)}") return 1 reparsed = _reparse_with_scenario_params( parsed_args=parsed_args, - supported_params=scenario_meta.get("supported_parameters") or [], + supported_params=scenario_meta.supported_parameters, ) if reparsed is None: return 1 diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 1a0760eb7c..acac3e1796 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -199,8 +199,8 @@ def do_list_scenarios(self, arg: str) -> None: from pyrit.cli import _output try: - resp = self._run_async(self._api_client.list_scenarios_async()) - _output.print_scenario_list(items=resp.get("items", [])) + scenarios = self._run_async(self._api_client.list_scenarios_async()) + _output.print_scenario_list(items=scenarios) except Exception as e: print(f"Error listing scenarios: {e}") @@ -214,8 +214,8 @@ def do_list_initializers(self, arg: str) -> None: from pyrit.cli import _output try: - resp = self._run_async(self._api_client.list_initializers_async()) - _output.print_initializer_list(items=resp.get("items", [])) + initializers = self._run_async(self._api_client.list_initializers_async()) + _output.print_initializer_list(items=initializers) except Exception as e: print(f"Error listing initializers: {e}") @@ -229,8 +229,8 @@ def do_list_targets(self, arg: str) -> None: from pyrit.cli import _output try: - resp = self._run_async(self._api_client.list_targets_async()) - _output.print_target_list(items=resp.get("items", [])) + targets = self._run_async(self._api_client.list_targets_async()) + _output.print_target_list(items=targets) except Exception as e: print(f"Error listing targets: {e}") @@ -308,6 +308,7 @@ def do_run(self, line: str) -> None: print_scenario_run_progress, print_scenario_run_summary, ) + from pyrit.models.catalog import RunScenarioRequest # Fetch scenario metadata so the parser recognizes scenario-declared flags. scenario_name_token = line.split(maxsplit=1)[0] @@ -319,7 +320,7 @@ def do_run(self, line: str) -> None: if scenario_meta is None: print(f"Error: Scenario '{scenario_name_token}' not found on server.") return - declared_params = build_parameters_from_api(api_params=scenario_meta.get("supported_parameters") or []) + declared_params = build_parameters_from_api(api_params=scenario_meta.supported_parameters) # Parse arguments try: @@ -330,8 +331,8 @@ def do_run(self, line: str) -> None: scenario_name = args["scenario_name"] - # Build request - request: dict[str, Any] = { + # Build typed request + request_kwargs: dict[str, Any] = { "scenario_name": scenario_name, "target_name": args.get("target") or "", } @@ -349,29 +350,31 @@ def do_run(self, line: str) -> None: init_names.append(name) if entry.get("args"): init_args[name] = entry["args"] - request["initializers"] = init_names + request_kwargs["initializers"] = init_names if init_args: - request["initializer_args"] = init_args + request_kwargs["initializer_args"] = init_args if args.get("scenario_strategies"): - request["strategies"] = args["scenario_strategies"] + request_kwargs["strategies"] = args["scenario_strategies"] if args.get("max_concurrency") is not None: - request["max_concurrency"] = args["max_concurrency"] + request_kwargs["max_concurrency"] = args["max_concurrency"] if args.get("max_retries") is not None: - request["max_retries"] = args["max_retries"] + request_kwargs["max_retries"] = args["max_retries"] if args.get("dataset_names"): - request["dataset_names"] = args["dataset_names"] + request_kwargs["dataset_names"] = args["dataset_names"] if args.get("max_dataset_size") is not None: - request["max_dataset_size"] = args["max_dataset_size"] + request_kwargs["max_dataset_size"] = args["max_dataset_size"] if args.get("memory_labels"): - request["labels"] = args["memory_labels"] + request_kwargs["labels"] = args["memory_labels"] scenario_params = extract_scenario_args(parsed=args) if scenario_params: - request["scenario_params"] = scenario_params + request_kwargs["scenario_params"] = scenario_params + + request = RunScenarioRequest(**request_kwargs) # Start run - total_strategies = len(request.get("strategies") or []) + total_strategies = len(request.strategies or []) print(f"\nRunning scenario: {scenario_name}") sys.stdout.flush() @@ -381,7 +384,7 @@ def do_run(self, line: str) -> None: print(f"Error starting scenario: {exc}") return - scenario_result_id = run.get("scenario_result_id", "") + scenario_result_id = run.scenario_result_id # Poll for completion import time @@ -389,9 +392,8 @@ def do_run(self, line: str) -> None: try: while True: run = self._run_async(self._api_client.get_scenario_run_async(scenario_result_id=scenario_result_id)) - status = run.get("status", "UNKNOWN") print_scenario_run_progress(run=run, total_strategies=total_strategies) - if status in self._TERMINAL_STATUSES: + if run.status in self._TERMINAL_STATUSES: break time.sleep(0.5) except KeyboardInterrupt: @@ -405,12 +407,12 @@ def do_run(self, line: str) -> None: return # Print results - if run.get("status") == "COMPLETED": + if run.status == "COMPLETED": try: detail = self._run_async( self._api_client.get_scenario_run_results_async(scenario_result_id=scenario_result_id) ) - self._run_async(print_scenario_result_async(result_dict=detail)) + self._run_async(print_scenario_result_async(result=detail)) except Exception: print_scenario_run_summary(run=run) else: @@ -443,8 +445,8 @@ def do_scenario_history(self, arg: str) -> None: from pyrit.cli._output import print_scenario_runs_list try: - resp = self._run_async(self._api_client.list_scenario_runs_async(limit=limit)) - print_scenario_runs_list(runs=resp.get("items", [])) + runs = self._run_async(self._api_client.list_scenario_runs_async(limit=limit)) + print_scenario_runs_list(runs=runs) except Exception as e: print(f"Error: {e}") @@ -467,7 +469,7 @@ def do_print_scenario(self, arg: str) -> None: try: detail = self._run_async(self._api_client.get_scenario_run_results_async(scenario_result_id=arg)) - self._run_async(print_scenario_result_async(result_dict=detail)) + self._run_async(print_scenario_result_async(result=detail)) except Exception as e: print(f"Error: {e}") diff --git a/pyrit/models/catalog/__init__.py b/pyrit/models/catalog/__init__.py new file mode 100644 index 0000000000..18e6c72454 --- /dev/null +++ b/pyrit/models/catalog/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Catalog sub-package - registry/wire-format types for scenarios, initializers, +and targets that the PyRIT REST API exposes to external clients. + +These models describe canonical PyRIT entities (a registered scenario, a +registered initializer, a runtime target instance, a scenario run summary) +and are imported by both the backend (as response/request payloads) and the +CLI (and any future external REST client). REST framing types (pagination +envelopes, RFC 7807 problem details, GUI-only request bodies) stay in +``pyrit.backend.models``; see the §2.1 rule of thumb in the migration plan. +""" + +from pyrit.models.catalog.initializer import ( + InitializerParameterSummary, + RegisteredInitializer, +) +from pyrit.models.catalog.scenario import ( + RegisteredScenario, + RunScenarioRequest, + ScenarioParameterSummary, + ScenarioRunSummary, +) +from pyrit.models.catalog.target import ( + TargetCapabilitiesInfo, + TargetInstance, +) + +__all__ = [ + "InitializerParameterSummary", + "RegisteredInitializer", + "RegisteredScenario", + "RunScenarioRequest", + "ScenarioParameterSummary", + "ScenarioRunSummary", + "TargetCapabilitiesInfo", + "TargetInstance", +] diff --git a/pyrit/models/catalog/initializer.py b/pyrit/models/catalog/initializer.py new file mode 100644 index 0000000000..abbe70f49d --- /dev/null +++ b/pyrit/models/catalog/initializer.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Initializer catalog models. + +Initializers configure the PyRIT environment (targets, datasets, env vars) +before scenario execution. These models describe registered-initializer +metadata that both the backend and external REST clients (the CLI today) +consume from ``/api/initializers``. + +Per-field documentation strings (``Field(..., description=...)``) deliberately +live in the backend layer rather than here — see ``pyrit.models.MessagePiece`` +vs ``pyrit.backend.models.attacks.MessagePieceView`` for the same split. +""" + +from pydantic import BaseModel, Field + + +class InitializerParameterSummary(BaseModel): + """Summary of an initializer-declared parameter.""" + + name: str + description: str + default: list[str] | None = None + + +class RegisteredInitializer(BaseModel): + """Summary of a registered initializer.""" + + initializer_name: str + initializer_type: str + description: str = "" + required_env_vars: list[str] = Field(default_factory=list) + supported_parameters: list[InitializerParameterSummary] = Field(default_factory=list) diff --git a/pyrit/models/catalog/scenario.py b/pyrit/models/catalog/scenario.py new file mode 100644 index 0000000000..9cdad4005c --- /dev/null +++ b/pyrit/models/catalog/scenario.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario catalog and run-summary models. + +These describe canonical PyRIT entities exposed over the REST catalog and +scenario-run endpoints; both the backend and external REST clients (the CLI +today) consume them. REST envelopes (pagination, list wrappers) stay in +``pyrit.backend.models``. + +Per-field documentation strings (``Field(..., description=...)``) deliberately +live in the backend layer rather than here — see ``pyrit.models.MessagePiece`` +vs ``pyrit.backend.models.attacks.MessagePieceView`` for the same split. +Validators that affect runtime behavior (``ge``, ``le``) remain on the +canonical models. +""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + +from pyrit.models.scenario_result import ScenarioRunState + + +class ScenarioParameterSummary(BaseModel): + """Summary of a scenario-declared parameter.""" + + name: str + description: str + default: str | None = None + param_type: str + choices: list[str] | None = None + is_list: bool = False + + +class RegisteredScenario(BaseModel): + """Summary of a registered scenario.""" + + scenario_name: str + scenario_type: str + description: str + default_strategy: str + aggregate_strategies: list[str] + all_strategies: list[str] + default_datasets: list[str] + max_dataset_size: int | None = None + supported_parameters: list[ScenarioParameterSummary] = Field(default_factory=list) + + +class RunScenarioRequest(BaseModel): + """Request body for starting a scenario run.""" + + scenario_name: str + target_name: str + initializers: list[str] | None = None + strategies: list[str] | None = None + dataset_names: list[str] | None = None + max_dataset_size: int | None = Field(None, ge=1) + max_concurrency: int = Field(10, ge=1, le=100) + max_retries: int = Field(0, ge=0, le=20) + labels: dict[str, str] | None = None + scenario_params: dict[str, Any] | None = None + initializer_args: dict[str, dict[str, Any]] | None = None + scenario_result_id: str | None = None + + +class ScenarioRunSummary(BaseModel): + """Response for a scenario run (status + result details).""" + + scenario_result_id: str + scenario_name: str + scenario_version: int = Field(0, ge=0) + status: ScenarioRunState + created_at: datetime + updated_at: datetime + error: str | None = None + error_type: str | None = None + strategies_used: list[str] = Field(default_factory=list) + total_attacks: int = Field(0, ge=0) + completed_attacks: int = Field(0, ge=0) + objective_achieved_rate: int = Field(0, ge=0, le=100) + labels: dict[str, str] = Field(default_factory=dict) + completed_at: datetime | None = None diff --git a/pyrit/models/catalog/target.py b/pyrit/models/catalog/target.py new file mode 100644 index 0000000000..fe27072d21 --- /dev/null +++ b/pyrit/models/catalog/target.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target instance catalog models. + +Targets have two concepts: + +- Types: Static metadata bundled with the frontend (from the registry). +- Instances: Runtime objects created via the API with specific configuration. + +The ``TargetInstance`` model is the wire-format snapshot for a runtime +target, used by both the backend (as a REST response payload) and external +REST clients (the CLI today, future external clients tomorrow). + +Per-field documentation strings (``Field(..., description=...)``) deliberately +live in the backend layer rather than here — see ``pyrit.models.MessagePiece`` +vs ``pyrit.backend.models.attacks.MessagePieceView`` for the same split. +""" + +from typing import Any + +from pydantic import BaseModel, Field + + +class TargetCapabilitiesInfo(BaseModel): + """ + Wire-format snapshot of a target's capabilities. + + Mirrors the domain ``TargetCapabilities`` dataclass for API consumers + (notably the GUI). Modality combinations (``frozenset[frozenset[...]]``) + are flattened into sorted unique modality lists since the frontend uses + them only for per-piece modality checks. + """ + + supports_multi_turn: bool = False + supports_multi_message_pieces: bool = False + supports_json_schema: bool = False + supports_json_output: bool = False + supports_editable_history: bool = False + supports_system_prompt: bool = False + supported_input_modalities: list[str] = Field(default_factory=lambda: ["text"]) + supported_output_modalities: list[str] = Field(default_factory=lambda: ["text"]) + + +class TargetInstance(BaseModel): + """ + A runtime target instance. + + Created either by an initializer (at startup) or by user (via API). + Also used as the create-target response (same shape as GET). + """ + + target_registry_name: str + target_type: str + endpoint: str | None = None + model_name: str | None = None + underlying_model_name: str | None = None + temperature: float | None = None + top_p: float | None = None + max_requests_per_minute: int | None = None + capabilities: TargetCapabilitiesInfo + target_specific_params: dict[str, Any] | None = None + inner_targets: list["TargetInstance"] | None = None + identifier_hash: str | None = None diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 29a48bd656..afc21724d2 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -35,12 +35,11 @@ PreviewStep, ) from pyrit.backend.models.targets import ( - TargetCapabilitiesInfo, - TargetInstance, TargetListResponse, ) from pyrit.backend.routes.labels import get_label_options from pyrit.models import MessagePiece +from pyrit.models.catalog.target import TargetCapabilitiesInfo, TargetInstance def _make_message_view(*, role: str = "user", value: str = "hello", sequence: int = 1) -> MessageView: diff --git a/tests/unit/backend/test_initializer_service.py b/tests/unit/backend/test_initializer_service.py index 6f52c5647a..ca6c1988d9 100644 --- a/tests/unit/backend/test_initializer_service.py +++ b/tests/unit/backend/test_initializer_service.py @@ -14,11 +14,13 @@ from pyrit.backend.main import app from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.initializers import ( - InitializerParameterSummary, ListRegisteredInitializersResponse, - RegisteredInitializer, ) from pyrit.backend.services.initializer_service import InitializerService, get_initializer_service +from pyrit.models.catalog.initializer import ( + InitializerParameterSummary, + RegisteredInitializer, +) from pyrit.registry import InitializerMetadata diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index faf40a5b8f..7223b6e01e 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -14,11 +14,9 @@ import pyrit.backend.services.scenario_run_service as _svc_mod from pyrit.backend.main import app -from pyrit.backend.models.scenarios import ( - ScenarioRunListResponse, - ScenarioRunStatus, - ScenarioRunSummary, -) +from pyrit.backend.models.scenarios import ScenarioRunListResponse +from pyrit.models import ScenarioRunState +from pyrit.models.catalog.scenario import ScenarioRunSummary @pytest.fixture @@ -39,7 +37,7 @@ def _mock_run_response( *, run_id: str = "test-run-id", scenario_name: str = "foundry.red_team_agent", - run_status: ScenarioRunStatus = ScenarioRunStatus.CREATED, + run_status: ScenarioRunState = ScenarioRunState.CREATED, ) -> ScenarioRunSummary: """Create a mock ScenarioRunResponse.""" return ScenarioRunSummary( @@ -142,7 +140,7 @@ def test_list_runs_returns_multiple_runs(self, client: TestClient) -> None: """Test that list runs returns all tracked runs.""" runs = [ _mock_run_response(run_id="run-1"), - _mock_run_response(run_id="run-2", run_status=ScenarioRunStatus.IN_PROGRESS), + _mock_run_response(run_id="run-2", run_status=ScenarioRunState.IN_PROGRESS), ] with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: @@ -161,7 +159,7 @@ class TestGetScenarioRunRoute: def test_get_run_returns_200(self, client: TestClient) -> None: """Test that getting an existing run returns 200.""" - mock_response = _mock_run_response(run_status=ScenarioRunStatus.IN_PROGRESS) + mock_response = _mock_run_response(run_status=ScenarioRunState.IN_PROGRESS) with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() @@ -190,7 +188,7 @@ class TestCancelScenarioRunRoute: def test_cancel_run_returns_200(self, client: TestClient) -> None: """Test that cancelling a running scenario returns 200.""" - mock_response = _mock_run_response(run_status=ScenarioRunStatus.CANCELLED) + mock_response = _mock_run_response(run_status=ScenarioRunState.CANCELLED) with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 0a7463d1f0..e990eadc46 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -12,15 +12,12 @@ import pytest import pyrit.backend.services.scenario_run_service as _svc_mod -from pyrit.backend.models.scenarios import ( - RunScenarioRequest, - ScenarioRunStatus, -) from pyrit.backend.services.scenario_run_service import ( _DEFAULT_MAX_CONCURRENT_RUNS, ScenarioRunService, ) -from pyrit.models import AttackOutcome +from pyrit.models import AttackOutcome, ScenarioRunState +from pyrit.models.catalog.scenario import RunScenarioRequest from pyrit.scenario.core import DatasetConfiguration _REGISTRY_PATCH_BASE = "pyrit.registry" @@ -147,7 +144,7 @@ async def test_start_run_returns_running_status(self, mock_all_registries) -> No response = await service.start_run_async(request=_make_request()) assert response.scenario_result_id == "sr-uuid-1" - assert response.status == ScenarioRunStatus.IN_PROGRESS + assert response.status == ScenarioRunState.IN_PROGRESS assert response.scenario_name == "foundry.red_team_agent" assert response.error is None @@ -448,7 +445,7 @@ async def test_start_run_runs_initializers(self, mock_all_registries) -> None: request=_make_request(initializers=["target", "load_default_datasets"]) ) - assert response.status == ScenarioRunStatus.IN_PROGRESS + assert response.status == ScenarioRunState.IN_PROGRESS assert mock_init_instance.initialize_async.await_count == 2 async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: @@ -458,7 +455,7 @@ async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_reg response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) - assert response.status == ScenarioRunStatus.IN_PROGRESS + assert response.status == ScenarioRunState.IN_PROGRESS mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: @@ -492,7 +489,7 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: assert fetched is not None assert fetched.scenario_result_id == "sr-123" assert fetched.scenario_name == "foundry.red_team_agent" - assert fetched.status == ScenarioRunStatus.IN_PROGRESS + assert fetched.status == ScenarioRunState.IN_PROGRESS def test_get_run_falls_back_to_persisted_error(self, mock_memory) -> None: """Test that get_run extracts error from persisted error AttackResult when no active task. @@ -585,7 +582,7 @@ async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> No error_type="CancelledError", ) assert result is not None - assert result.status == ScenarioRunStatus.CANCELLED + assert result.status == ScenarioRunState.CANCELLED async def test_cancel_completed_run_raises_value_error(self, mock_memory) -> None: """Test that cancelling a completed run raises ValueError.""" @@ -738,7 +735,7 @@ def test_in_progress_run_shows_partial_attack_counts(self, mock_memory) -> None: fetched = service.get_run(scenario_result_id="sr-running") assert fetched is not None - assert fetched.status == ScenarioRunStatus.IN_PROGRESS + assert fetched.status == ScenarioRunState.IN_PROGRESS assert fetched.total_attacks == 3 assert fetched.completed_attacks == 3 assert fetched.strategies_used == ["attack_a", "attack_b"] @@ -757,7 +754,7 @@ def test_created_run_shows_zero_counts(self, mock_memory) -> None: fetched = service.get_run(scenario_result_id="sr-new") assert fetched is not None - assert fetched.status == ScenarioRunStatus.CREATED + assert fetched.status == ScenarioRunState.CREATED assert fetched.total_attacks == 0 assert fetched.completed_attacks == 0 assert fetched.strategies_used == [] @@ -782,7 +779,7 @@ def test_completed_run_still_shows_full_counts(self, mock_memory) -> None: fetched = service.get_run(scenario_result_id="sr-done") assert fetched is not None - assert fetched.status == ScenarioRunStatus.COMPLETED + assert fetched.status == ScenarioRunState.COMPLETED assert fetched.total_attacks == 1 assert fetched.completed_attacks == 1 assert fetched.strategies_used == ["attack_a"] diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 6aab39f0f9..262d0e582a 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -13,8 +13,9 @@ from pyrit.backend.main import app from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse, RegisteredScenario +from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service +from pyrit.models.catalog.scenario import RegisteredScenario from pyrit.registry import ScenarioMetadata from pyrit.registry.class_registries.scenario_registry import ScenarioParameterMetadata diff --git a/tests/unit/cli/test_api_client.py b/tests/unit/cli/test_api_client.py index 95c232273e..ebeb54e7d8 100644 --- a/tests/unit/cli/test_api_client.py +++ b/tests/unit/cli/test_api_client.py @@ -5,12 +5,22 @@ Unit tests for pyrit.cli.api_client.PyRITApiClient. """ +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from pyrit.cli.api_client import PyRITApiClient, ServerNotAvailableError +from pyrit.models import ScenarioRunState +from pyrit.models.catalog import ( + RegisteredInitializer, + RegisteredScenario, + RunScenarioRequest, + ScenarioRunSummary, + TargetCapabilitiesInfo, + TargetInstance, +) @pytest.fixture() @@ -39,6 +49,68 @@ def _make_response(*, status_code=200, json_data=None): return resp +def _scenario_payload(*, scenario_name: str = "s1") -> dict: + """Build a wire-format ``RegisteredScenario`` payload.""" + return { + "scenario_name": scenario_name, + "scenario_type": "RedTeamAgentScenario", + "description": "test scenario", + "default_strategy": "single_turn", + "aggregate_strategies": [], + "all_strategies": ["single_turn"], + "default_datasets": [], + "max_dataset_size": None, + "supported_parameters": [], + } + + +def _initializer_payload(*, initializer_name: str = "x") -> dict: + return { + "initializer_name": initializer_name, + "initializer_type": "TargetInitializer", + "description": "", + "required_env_vars": [], + "supported_parameters": [], + } + + +def _target_payload(*, target_registry_name: str = "t1") -> dict: + return { + "target_registry_name": target_registry_name, + "target_type": "OpenAIChatTarget", + "endpoint": None, + "model_name": None, + "underlying_model_name": None, + "temperature": None, + "top_p": None, + "max_requests_per_minute": None, + "capabilities": TargetCapabilitiesInfo().model_dump(mode="json"), + "target_specific_params": None, + "inner_targets": None, + "identifier_hash": None, + } + + +def _run_summary_payload(*, scenario_result_id: str = "abc", status: str = "CREATED") -> dict: + now = datetime(2025, 1, 1, tzinfo=timezone.utc).isoformat() + return { + "scenario_result_id": scenario_result_id, + "scenario_name": "x", + "scenario_version": 0, + "status": status, + "created_at": now, + "updated_at": now, + "error": None, + "error_type": None, + "strategies_used": [], + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "labels": {}, + "completed_at": None, + } + + # --------------------------------------------------------------------------- # Init / context manager / lifecycle # --------------------------------------------------------------------------- @@ -124,18 +196,20 @@ async def test_health_check_returns_false_on_generic_exception(client, mock_http async def test_list_scenarios_async(client, mock_httpx_client): - payload = {"items": [{"scenario_name": "s1"}], "pagination": {}} + payload = {"items": [_scenario_payload(scenario_name="s1")], "pagination": {}} mock_httpx_client.get.return_value = _make_response(json_data=payload) result = await client.list_scenarios_async(limit=10) - assert result == payload + assert len(result) == 1 + assert isinstance(result[0], RegisteredScenario) + assert result[0].scenario_name == "s1" mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/catalog", params={"limit": 10}) async def test_get_scenario_async_returns_payload(client, mock_httpx_client): - payload = {"scenario_name": "foo"} - mock_httpx_client.get.return_value = _make_response(json_data=payload) + mock_httpx_client.get.return_value = _make_response(json_data=_scenario_payload(scenario_name="foo")) result = await client.get_scenario_async(scenario_name="foo") - assert result == payload + assert isinstance(result, RegisteredScenario) + assert result.scenario_name == "foo" mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/catalog/foo", params=None) @@ -163,16 +237,19 @@ async def test_get_scenario_async_raises_on_other_http_errors(client, mock_httpx async def test_list_initializers_async(client, mock_httpx_client): - mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) - await client.list_initializers_async(limit=5) + mock_httpx_client.get.return_value = _make_response(json_data={"items": [_initializer_payload()]}) + result = await client.list_initializers_async(limit=5) + assert len(result) == 1 + assert isinstance(result[0], RegisteredInitializer) mock_httpx_client.get.assert_awaited_once_with("/api/initializers", params={"limit": 5}) async def test_register_initializer_async_success(client, mock_httpx_client): - payload = {"initializer_name": "x"} + payload = _initializer_payload(initializer_name="x") mock_httpx_client.post.return_value = _make_response(json_data=payload) result = await client.register_initializer_async(name="x", script_content="print(1)") - assert result == payload + assert isinstance(result, RegisteredInitializer) + assert result.initializer_name == "x" mock_httpx_client.post.assert_awaited_once_with( "/api/initializers", json={"name": "x", "script_content": "print(1)"} ) @@ -199,8 +276,10 @@ async def test_register_initializer_async_raises_on_500(client, mock_httpx_clien async def test_list_targets_async(client, mock_httpx_client): - mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) - await client.list_targets_async(limit=7) + mock_httpx_client.get.return_value = _make_response(json_data={"items": [_target_payload()]}) + result = await client.list_targets_async(limit=7) + assert len(result) == 1 + assert isinstance(result[0], TargetInstance) mock_httpx_client.get.assert_awaited_once_with("/api/targets", params={"limit": 7}) @@ -210,20 +289,28 @@ async def test_list_targets_async(client, mock_httpx_client): async def test_start_scenario_run_async(client, mock_httpx_client): - payload = {"scenario_result_id": "abc"} - mock_httpx_client.post.return_value = _make_response(json_data=payload) - request = {"scenario_name": "x"} + mock_httpx_client.post.return_value = _make_response(json_data=_run_summary_payload(scenario_result_id="abc")) + request = RunScenarioRequest(scenario_name="x", target_name="t") result = await client.start_scenario_run_async(request=request) - assert result == payload - mock_httpx_client.post.assert_awaited_once_with("/api/scenarios/runs", json=request) + assert isinstance(result, ScenarioRunSummary) + assert result.scenario_result_id == "abc" + mock_httpx_client.post.assert_awaited_once() + args, kwargs = mock_httpx_client.post.call_args + assert args == ("/api/scenarios/runs",) + # The CLI serializes the typed request via model_dump(mode="json", exclude_none=True); + # required fields must appear in the body, None-valued fields must not. + assert kwargs["json"]["scenario_name"] == "x" + assert kwargs["json"]["target_name"] == "t" + assert "scenario_params" not in kwargs["json"] async def test_get_scenario_run_async(client, mock_httpx_client): import httpx as _httpx - mock_httpx_client.get.return_value = _make_response(json_data={"status": "RUNNING"}) + mock_httpx_client.get.return_value = _make_response(json_data=_run_summary_payload(status="IN_PROGRESS")) result = await client.get_scenario_run_async(scenario_result_id="abc") - assert result == {"status": "RUNNING"} + assert isinstance(result, ScenarioRunSummary) + assert result.status == ScenarioRunState.IN_PROGRESS # Polling uses read=None so a busy server doesn't trip the client default # timeout while a scenario is executing. mock_httpx_client.get.assert_awaited_once() @@ -243,22 +330,37 @@ async def test_get_scenario_run_async_wraps_connect_error(client, mock_httpx_cli async def test_get_scenario_run_results_async(client, mock_httpx_client): - mock_httpx_client.get.return_value = _make_response(json_data={"run": {}, "attacks": []}) + # Build a minimal ScenarioResult.to_dict() payload that from_dict can deserialize. + from pyrit.models import ScenarioIdentifier, ScenarioResult, ScenarioRunState + + scenario_result = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="x"), + objective_target_identifier=None, + objective_scorer_identifier=None, + attack_results={}, + scenario_run_state=ScenarioRunState.COMPLETED, + ) + mock_httpx_client.get.return_value = _make_response( + json_data=scenario_result.model_dump(mode="json", by_alias=True) + ) result = await client.get_scenario_run_results_async(scenario_result_id="abc") - assert "run" in result + assert isinstance(result, ScenarioResult) mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/runs/abc/results", params=None) async def test_cancel_scenario_run_async(client, mock_httpx_client): - mock_httpx_client.post.return_value = _make_response(json_data={"status": "CANCELLED"}) + mock_httpx_client.post.return_value = _make_response(json_data=_run_summary_payload(status="CANCELLED")) result = await client.cancel_scenario_run_async(scenario_result_id="abc") - assert result == {"status": "CANCELLED"} + assert isinstance(result, ScenarioRunSummary) + assert result.status == ScenarioRunState.CANCELLED mock_httpx_client.post.assert_awaited_once_with("/api/scenarios/runs/abc/cancel") async def test_list_scenario_runs_async(client, mock_httpx_client): - mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) - await client.list_scenario_runs_async(limit=20) + mock_httpx_client.get.return_value = _make_response(json_data={"items": [_run_summary_payload()]}) + result = await client.list_scenario_runs_async(limit=20) + assert len(result) == 1 + assert isinstance(result[0], ScenarioRunSummary) mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/runs", params={"limit": 20}) diff --git a/tests/unit/cli/test_import_guards.py b/tests/unit/cli/test_import_guards.py index 3b87b2d442..cf5b062165 100644 --- a/tests/unit/cli/test_import_guards.py +++ b/tests/unit/cli/test_import_guards.py @@ -44,7 +44,12 @@ def _check_forbidden_imports(*, import_statement: str, forbidden: list[str]) -> # Heavy modules that should never be loaded during CLI arg parsing. -# This ensures `pyrit_scan --help` stays near-instant (~0.3s). +# This ensures `pyrit_scan --help` stays near-instant (~0.4s). +# +# ``pyrit.models`` and ``pyrit.backend`` are the real cost we're guarding +# against — eagerly importing ``pyrit.models`` adds ~500ms to ``--help``. +# ``pydantic`` is kept as cheap insurance against eager BaseModel class +# compilation creeping into the bootstrap path. _CLI_FORBIDDEN = [ "alembic", "av", @@ -54,6 +59,8 @@ def _check_forbidden_imports(*, import_statement: str, forbidden: list[str]) -> "openai", "pandas", "pydantic", + "pyrit.backend", + "pyrit.models", "scipy", "sqlalchemy", "torch", diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py index 26f41cccc2..69f9654563 100644 --- a/tests/unit/cli/test_output.py +++ b/tests/unit/cli/test_output.py @@ -3,11 +3,100 @@ """ Unit tests for pyrit.cli._output formatting helpers. + +All public ``print_*`` functions accept typed ``pyrit.models`` objects +(``RegisteredScenario``, ``RegisteredInitializer``, ``TargetInstance``, +``ScenarioRunSummary``, ``ScenarioResult``). """ +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch from pyrit.cli import _output +from pyrit.models import ScenarioRunState +from pyrit.models.catalog import ( + InitializerParameterSummary, + RegisteredInitializer, + RegisteredScenario, + ScenarioParameterSummary, + ScenarioRunSummary, + TargetCapabilitiesInfo, + TargetInstance, +) + +# --------------------------------------------------------------------------- +# Typed-object factory helpers +# --------------------------------------------------------------------------- + + +def _make_scenario(**overrides) -> RegisteredScenario: + defaults = { + "scenario_name": "s1", + "scenario_type": "X", + "description": "", + "default_strategy": "", + "aggregate_strategies": [], + "all_strategies": [], + "default_datasets": [], + "max_dataset_size": None, + "supported_parameters": [], + } + defaults.update(overrides) + return RegisteredScenario(**defaults) + + +def _make_initializer(**overrides) -> RegisteredInitializer: + defaults = { + "initializer_name": "i1", + "initializer_type": "T", + "description": "", + "required_env_vars": [], + "supported_parameters": [], + } + defaults.update(overrides) + return RegisteredInitializer(**defaults) + + +def _make_target(**overrides) -> TargetInstance: + defaults = { + "target_registry_name": "t1", + "target_type": "X", + "endpoint": None, + "model_name": None, + "underlying_model_name": None, + "temperature": None, + "top_p": None, + "max_requests_per_minute": None, + "capabilities": TargetCapabilitiesInfo(), + "target_specific_params": None, + "inner_targets": None, + "identifier_hash": None, + } + defaults.update(overrides) + return TargetInstance(**defaults) + + +def _make_run(**overrides) -> ScenarioRunSummary: + now = datetime(2025, 1, 1, tzinfo=timezone.utc) + defaults = { + "scenario_result_id": "abc-123", + "scenario_name": "test_sc", + "scenario_version": 0, + "status": ScenarioRunState.CREATED, + "created_at": now, + "updated_at": now, + "error": None, + "error_type": None, + "strategies_used": [], + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "labels": {}, + "completed_at": None, + } + defaults.update(overrides) + return ScenarioRunSummary(**defaults) + # --------------------------------------------------------------------------- # Internal helpers @@ -74,32 +163,32 @@ def test_print_scenario_list_empty(capsys): def test_print_scenario_list_full(capsys): items = [ - { - "scenario_name": "airt.scam", - "scenario_type": "ScamScenario", - "description": "A test scenario.", - "aggregate_strategies": ["single_turn"], - "all_strategies": ["s1", "s2", "s3"], - "default_strategy": "s1", - "default_datasets": ["d1", "d2"], - "max_dataset_size": 50, - "supported_parameters": [ - { - "name": "max_turns", - "default": 5, - "param_type": "int", - "choices": None, - "description": "Maximum turns.", - }, - { - "name": "mode", - "default": None, - "param_type": "str", - "choices": ["a", "b"], - "description": "Mode.", - }, + _make_scenario( + scenario_name="airt.scam", + scenario_type="ScamScenario", + description="A test scenario.", + aggregate_strategies=["single_turn"], + all_strategies=["s1", "s2", "s3"], + default_strategy="s1", + default_datasets=["d1", "d2"], + max_dataset_size=50, + supported_parameters=[ + ScenarioParameterSummary( + name="max_turns", + default="5", + param_type="int", + choices=None, + description="Maximum turns.", + ), + ScenarioParameterSummary( + name="mode", + default=None, + param_type="str", + choices=["a", "b"], + description="Mode.", + ), ], - } + ) ] _output.print_scenario_list(items=items) captured = capsys.readouterr() @@ -118,7 +207,7 @@ def test_print_scenario_list_full(capsys): def test_print_scenario_list_minimal_fields(capsys): - items = [{"scenario_name": "min", "scenario_type": "MinScenario"}] + items = [_make_scenario(scenario_name="min", scenario_type="MinScenario")] _output.print_scenario_list(items=items) captured = capsys.readouterr() assert "min" in captured.out @@ -127,11 +216,11 @@ def test_print_scenario_list_minimal_fields(capsys): def test_print_scenario_list_no_max_dataset_size(capsys): items = [ - { - "scenario_name": "no_max", - "scenario_type": "T", - "default_datasets": ["d1"], - } + _make_scenario( + scenario_name="no_max", + scenario_type="T", + default_datasets=["d1"], + ) ] _output.print_scenario_list(items=items) captured = capsys.readouterr() @@ -152,21 +241,21 @@ def test_print_initializer_list_empty(capsys): def test_print_initializer_list_full(capsys): items = [ - { - "initializer_name": "openai_target", - "initializer_type": "OpenAITargetInitializer", - "required_env_vars": ["OPENAI_API_KEY", "OPENAI_ENDPOINT"], - "supported_parameters": [ - {"name": "model", "default": "gpt-4", "description": "Model name."}, - {"name": "temp", "default": None, "description": "Temperature."}, + _make_initializer( + initializer_name="openai_target", + initializer_type="OpenAITargetInitializer", + required_env_vars=["OPENAI_API_KEY", "OPENAI_ENDPOINT"], + supported_parameters=[ + InitializerParameterSummary(name="model", default=["gpt-4"], description="Model name."), + InitializerParameterSummary(name="temp", default=None, description="Temperature."), ], - "description": "Registers OpenAI targets.", - }, - { - "initializer_name": "no_env", - "initializer_type": "NoEnvInitializer", - "required_env_vars": [], - }, + description="Registers OpenAI targets.", + ), + _make_initializer( + initializer_name="no_env", + initializer_type="NoEnvInitializer", + required_env_vars=[], + ), ] _output.print_initializer_list(items=items) captured = capsys.readouterr() @@ -193,21 +282,21 @@ def test_print_target_list_empty(capsys): def test_print_target_list_full(capsys): items = [ - { - "target_registry_name": "openai_chat", - "target_type": "OpenAIChatTarget", - "underlying_model_name": "gpt-4", - "endpoint": "https://example.com", - }, - { - "target_registry_name": "claude", - "target_type": "AnthropicTarget", - "model_name": "claude-sonnet", - }, - { - "target_registry_name": "minimal", - "target_type": "MinimalTarget", - }, + _make_target( + target_registry_name="openai_chat", + target_type="OpenAIChatTarget", + underlying_model_name="gpt-4", + endpoint="https://example.com", + ), + _make_target( + target_registry_name="claude", + target_type="AnthropicTarget", + model_name="claude-sonnet", + ), + _make_target( + target_registry_name="minimal", + target_type="MinimalTarget", + ), ] _output.print_target_list(items=items) captured = capsys.readouterr() @@ -225,43 +314,43 @@ def test_print_target_list_full(capsys): def test_print_scenario_run_progress_with_known_totals(capsys): - run = { - "status": "RUNNING", - "total_attacks": 10, - "completed_attacks": 5, - "objective_achieved_rate": 30, - "strategies_used": ["s1", "s2"], - } + run = _make_run( + status=ScenarioRunState.IN_PROGRESS, + total_attacks=10, + completed_attacks=5, + objective_achieved_rate=30, + strategies_used=["s1", "s2"], + ) _output.print_scenario_run_progress(run=run, total_strategies=4) captured = capsys.readouterr() assert "strategies: 2/4" in captured.out assert "5/10" in captured.out - assert "RUNNING" in captured.out + assert "IN_PROGRESS" in captured.out assert "30%" in captured.out def test_print_scenario_run_progress_no_total_attacks(capsys): - run = { - "status": "PENDING", - "total_attacks": 0, - "completed_attacks": 0, - "objective_achieved_rate": 0, - "strategies_used": [], - } + run = _make_run( + status=ScenarioRunState.CREATED, + total_attacks=0, + completed_attacks=0, + objective_achieved_rate=0, + strategies_used=[], + ) _output.print_scenario_run_progress(run=run, total_strategies=0) captured = capsys.readouterr() assert "attacks: 0" in captured.out - assert "PENDING" in captured.out + assert "CREATED" in captured.out def test_print_scenario_run_progress_strategies_done_only(capsys): - run = { - "status": "RUNNING", - "total_attacks": 0, - "completed_attacks": 0, - "objective_achieved_rate": 0, - "strategies_used": ["s1"], - } + run = _make_run( + status=ScenarioRunState.IN_PROGRESS, + total_attacks=0, + completed_attacks=0, + objective_achieved_rate=0, + strategies_used=["s1"], + ) _output.print_scenario_run_progress(run=run, total_strategies=0) captured = capsys.readouterr() assert "strategies: 1" in captured.out @@ -273,15 +362,15 @@ def test_print_scenario_run_progress_strategies_done_only(capsys): def test_print_scenario_run_summary_completed(capsys): - run = { - "scenario_name": "test_sc", - "scenario_result_id": "abc-123", - "status": "COMPLETED", - "total_attacks": 5, - "completed_attacks": 5, - "objective_achieved_rate": 40, - "strategies_used": ["s1", "s2"], - } + run = _make_run( + scenario_name="test_sc", + scenario_result_id="abc-123", + status=ScenarioRunState.COMPLETED, + total_attacks=5, + completed_attacks=5, + objective_achieved_rate=40, + strategies_used=["s1", "s2"], + ) _output.print_scenario_run_summary(run=run) captured = capsys.readouterr() assert "test_sc" in captured.out @@ -292,15 +381,15 @@ def test_print_scenario_run_summary_completed(capsys): def test_print_scenario_run_summary_with_error(capsys): - run = { - "scenario_name": "failing", - "scenario_result_id": "id", - "status": "FAILED", - "total_attacks": 0, - "completed_attacks": 0, - "objective_achieved_rate": 0, - "error": "boom", - } + run = _make_run( + scenario_name="failing", + scenario_result_id="id", + status=ScenarioRunState.FAILED, + total_attacks=0, + completed_attacks=0, + objective_achieved_rate=0, + error="boom", + ) _output.print_scenario_run_summary(run=run) captured = capsys.readouterr() assert "Error:" in captured.out @@ -313,37 +402,28 @@ def test_print_scenario_run_summary_with_error(capsys): async def test_print_scenario_result_async_uses_pretty_printer(): - result_dict = {"some": "data"} + """``print_scenario_result_async`` hands the typed ``ScenarioResult`` to the pretty printer.""" fake_scenario = MagicMock() fake_printer = MagicMock() fake_printer.write_async = AsyncMock() - with ( - patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, - patch( - "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer - ) as printer_cls, - ): - await _output.print_scenario_result_async(result_dict=result_dict) + with patch( + "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", + return_value=fake_printer, + ) as printer_cls: + await _output.print_scenario_result_async(result=fake_scenario) - from_dict_mock.assert_called_once_with(result_dict) printer_cls.assert_called_once_with() fake_printer.write_async.assert_awaited_once_with(fake_scenario) -async def test_print_scenario_result_async_roundtrip_with_real_payload(): - """ - Integration smoke test: a real ScenarioResult.to_dict() payload must flow - through ScenarioResult.from_dict() inside print_scenario_result_async - without raising. Locks the REST contract used by the CLI thin client. - """ - from datetime import datetime, timezone - +async def test_print_scenario_result_async_accepts_real_scenario_result(): + """A real ``ScenarioResult`` instance flows through ``print_scenario_result_async``.""" from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult identifier = ScenarioIdentifier(name="test.scenario", description="A test") - target_identifier = ComponentIdentifier.from_dict( + target_identifier = ComponentIdentifier.model_validate( {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} ) attack = AttackResult( @@ -354,31 +434,23 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload(): execution_time_ms=150, timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), ) - original = ScenarioResult( + scenario_result = ScenarioResult( scenario_identifier=identifier, objective_target_identifier=target_identifier, objective_scorer_identifier=None, attack_results={"strat_a": [attack]}, - scenario_run_state="COMPLETED", + scenario_run_state=ScenarioRunState.COMPLETED, ) - payload = original.to_dict() - # Drive print_scenario_result_async through the real from_dict path; only - # stub the printer to keep the test fast. fake_printer = MagicMock() fake_printer.write_async = AsyncMock() with patch( "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer, ): - await _output.print_scenario_result_async(result_dict=payload) + await _output.print_scenario_result_async(result=scenario_result) - fake_printer.write_async.assert_awaited_once() - reconstructed = fake_printer.write_async.await_args.args[0] - assert isinstance(reconstructed, ScenarioResult) - assert reconstructed.scenario_identifier.name == "test.scenario" - assert list(reconstructed.attack_results.keys()) == ["strat_a"] - assert reconstructed.attack_results["strat_a"][0].outcome == AttackOutcome.SUCCESS + fake_printer.write_async.assert_awaited_once_with(scenario_result) # --------------------------------------------------------------------------- @@ -394,22 +466,22 @@ def test_print_scenario_runs_list_empty(capsys): def test_print_scenario_runs_list_populated(capsys): runs = [ - { - "status": "COMPLETED", - "scenario_name": "scen-a", - "scenario_result_id": "abcdefgh1234", - "total_attacks": 4, - "objective_achieved_rate": 75, - "created_at": "2024-01-01", - }, - { - "status": "RUNNING", - "scenario_name": "scen-b", - "scenario_result_id": "ijklmnop5678", - "total_attacks": 0, - "objective_achieved_rate": 0, - "created_at": "2024-02-02", - }, + _make_run( + status=ScenarioRunState.COMPLETED, + scenario_name="scen-a", + scenario_result_id="abcdefgh1234", + total_attacks=4, + objective_achieved_rate=75, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + _make_run( + status=ScenarioRunState.IN_PROGRESS, + scenario_name="scen-b", + scenario_result_id="ijklmnop5678", + total_attacks=0, + objective_achieved_rate=0, + created_at=datetime(2024, 2, 2, tzinfo=timezone.utc), + ), ] _output.print_scenario_runs_list(runs=runs) captured = capsys.readouterr() diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index d204e4433c..e14f1ba36a 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -150,41 +150,66 @@ def test_scenario_keys_extracted_with_prefix_stripped(self): def _mock_api_client(): - """Create a mock PyRITApiClient with default response behaviors.""" + """Create a mock PyRITApiClient with default response behaviors (typed wire-data).""" + from datetime import datetime, timezone + + from pyrit.models import ScenarioRunState + from pyrit.models.catalog import ( + RegisteredScenario, + ScenarioRunSummary, + TargetCapabilitiesInfo, + TargetInstance, + ) + + now = datetime(2025, 1, 1, tzinfo=timezone.utc) + client = AsyncMock() client.health_check_async.return_value = True - client.list_scenarios_async.return_value = {"items": [], "pagination": {"total": 0}} - client.list_initializers_async.return_value = {"items": [], "pagination": {"total": 0}} - client.list_targets_async.return_value = {"items": [], "pagination": {"total": 0}} - client.get_scenario_async.return_value = { - "scenario_name": "test_scenario", - "supported_parameters": [], - } - client.start_scenario_run_async.return_value = { - "scenario_result_id": "test-id-123", - "scenario_name": "test_scenario", - "status": "CREATED", - } - client.get_scenario_run_async.return_value = { - "scenario_result_id": "test-id-123", - "status": "COMPLETED", - "total_attacks": 5, - "completed_attacks": 5, - "objective_achieved_rate": 40, - } - client.get_scenario_run_results_async.return_value = { - "run": { - "scenario_result_id": "test-id-123", - "scenario_name": "test_scenario", - "status": "COMPLETED", - "total_attacks": 5, - "completed_attacks": 5, - "objective_achieved_rate": 40, - }, - "attacks": [], - } + client.list_scenarios_async.return_value = [] + client.list_initializers_async.return_value = [] + client.list_targets_async.return_value = [] + client.get_scenario_async.return_value = RegisteredScenario( + scenario_name="test_scenario", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + supported_parameters=[], + ) + client.start_scenario_run_async.return_value = ScenarioRunSummary( + scenario_result_id="test-id-123", + scenario_name="test_scenario", + scenario_version=0, + status=ScenarioRunState.CREATED, + created_at=now, + updated_at=now, + strategies_used=[], + total_attacks=0, + completed_attacks=0, + objective_achieved_rate=0, + ) + client.get_scenario_run_async.return_value = ScenarioRunSummary( + scenario_result_id="test-id-123", + scenario_name="test_scenario", + scenario_version=0, + status=ScenarioRunState.COMPLETED, + created_at=now, + updated_at=now, + strategies_used=[], + total_attacks=5, + completed_attacks=5, + objective_achieved_rate=40, + ) + # get_scenario_run_results_async returns ScenarioResult; tests that need + # a real one patch this individually. Default: raise so callers must opt-in. + client.get_scenario_run_results_async.side_effect = RuntimeError("results not configured") client.__aenter__ = AsyncMock(return_value=client) client.__aexit__ = AsyncMock(return_value=None) + # Marker so tests that re-shape the mock can find the unused TargetInstance helper. + _ = (TargetCapabilitiesInfo, TargetInstance) return client @@ -257,7 +282,7 @@ def test_main_run_scenario_with_initializers(self, mock_client_class, mock_probe assert result == 0 call_kwargs = mock_client.start_scenario_run_async.call_args.kwargs request = call_kwargs["request"] - assert request["initializers"] == ["target", "datasets"] + assert request.initializers == ["target", "datasets"] @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=False) def test_main_server_not_available(self, mock_probe, capsys): @@ -295,15 +320,25 @@ def test_main_scenario_not_found(self, mock_client_class, mock_probe, capsys): @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_failed_scenario(self, mock_client_class, mock_probe): """Test main when scenario run fails.""" + from datetime import datetime, timezone + + from pyrit.models import ScenarioRunState + from pyrit.models.catalog import ScenarioRunSummary + + now = datetime(2025, 1, 1, tzinfo=timezone.utc) mock_client = _mock_api_client() - mock_client.get_scenario_run_async.return_value = { - "scenario_result_id": "test-id", - "status": "FAILED", - "total_attacks": 0, - "completed_attacks": 0, - "objective_achieved_rate": 0, - "error": "Something went wrong", - } + mock_client.get_scenario_run_async.return_value = ScenarioRunSummary( + scenario_result_id="test-id", + scenario_name="test_scenario", + scenario_version=0, + status=ScenarioRunState.FAILED, + created_at=now, + updated_at=now, + total_attacks=0, + completed_attacks=0, + objective_achieved_rate=0, + error="Something went wrong", + ) mock_client_class.return_value = mock_client result = pyrit_scan.main(["test_scenario", "--target", "t"]) @@ -390,12 +425,14 @@ class TestAddScenarioParamsFromApi: def test_adds_unseen_params_as_optional_flags(self): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() pyrit_scan._add_scenario_params_from_api( parser=parser, params=[ - {"name": "max_turns", "description": "Max turns."}, - {"name": "mode", "description": "Mode."}, + ScenarioParameterSummary(name="max_turns", description="Max turns.", param_type="str"), + ScenarioParameterSummary(name="mode", description="Mode.", param_type="str"), ], ) parsed = parser.parse_args(["--max-turns", "5", "--mode", "fast"]) @@ -405,11 +442,13 @@ def test_adds_unseen_params_as_optional_flags(self): def test_skips_params_that_collide_with_existing_flags(self): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() parser.add_argument("--target") pyrit_scan._add_scenario_params_from_api( parser=parser, - params=[{"name": "target", "description": "..."}], + params=[ScenarioParameterSummary(name="target", description="...", param_type="str")], ) parsed = parser.parse_args(["--target", "x"]) # Original --target wins; no scenario__target added. @@ -432,8 +471,8 @@ def test_includes_initializer_args(self): memory_labels=None, ) request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") - assert request["initializers"] == ["openai_target", "datasets"] - assert request["initializer_args"] == {"openai_target": {"model": "gpt-4"}} + assert request.initializers == ["openai_target", "datasets"] + assert request.initializer_args == {"openai_target": {"model": "gpt-4"}} def test_populates_optional_fields(self): parsed = Namespace( @@ -447,12 +486,12 @@ def test_populates_optional_fields(self): memory_labels='{"key":"value"}', ) request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") - assert request["strategies"] == ["s1"] - assert request["max_concurrency"] == 3 - assert request["max_retries"] == 2 - assert request["dataset_names"] == ["d1"] - assert request["max_dataset_size"] == 10 - assert request["labels"] == {"key": "value"} + assert request.strategies == ["s1"] + assert request.max_concurrency == 3 + assert request.max_retries == 2 + assert request.dataset_names == ["d1"] + assert request.max_dataset_size == 10 + assert request.labels == {"key": "value"} def test_includes_scenario_declared_params(self): parsed = Namespace( @@ -467,7 +506,7 @@ def test_includes_scenario_declared_params(self): scenario__max_turns="7", ) request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") - assert request["scenario_params"] == {"max_turns": "7"} + assert request.scenario_params == {"max_turns": "7"} class TestResolveServerUrl: @@ -603,10 +642,12 @@ class TestScenarioParamCoercion: def test_list_param_uses_nargs_plus(self): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() pyrit_scan._add_scenario_params_from_api( parser=parser, - params=[{"name": "items", "description": "...", "param_type": "list[str]", "is_list": True}], + params=[ScenarioParameterSummary(name="items", description="...", param_type="list[str]", is_list=True)], ) parsed = parser.parse_args(["--items", "a", "b", "c"]) assert parsed.scenario__items == ["a", "b", "c"] @@ -614,10 +655,12 @@ def test_list_param_uses_nargs_plus(self): def test_int_param_is_coerced(self): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() pyrit_scan._add_scenario_params_from_api( parser=parser, - params=[{"name": "max_turns", "description": "...", "param_type": "int"}], + params=[ScenarioParameterSummary(name="max_turns", description="...", param_type="int")], ) parsed = parser.parse_args(["--max-turns", "7"]) assert parsed.scenario__max_turns == 7 @@ -625,10 +668,12 @@ def test_int_param_is_coerced(self): def test_int_param_invalid_value_rejected_client_side(self, capsys): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() pyrit_scan._add_scenario_params_from_api( parser=parser, - params=[{"name": "max_turns", "description": "...", "param_type": "int"}], + params=[ScenarioParameterSummary(name="max_turns", description="...", param_type="int")], ) with pytest.raises(SystemExit): parser.parse_args(["--max-turns", "not-an-int"]) @@ -637,10 +682,14 @@ def test_int_param_invalid_value_rejected_client_side(self, capsys): def test_choices_validated_client_side(self, capsys): from argparse import ArgumentParser + from pyrit.models.catalog import ScenarioParameterSummary + parser = ArgumentParser() pyrit_scan._add_scenario_params_from_api( parser=parser, - params=[{"name": "mode", "description": "...", "param_type": "str", "choices": ["fast", "slow"]}], + params=[ + ScenarioParameterSummary(name="mode", description="...", param_type="str", choices=["fast", "slow"]) + ], ) parsed = parser.parse_args(["--mode", "fast"]) assert parsed.scenario__mode == "fast" @@ -662,12 +711,32 @@ def test_main_no_args_prints_help_and_exits_zero(self, capsys): @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_probe, capsys): + from pyrit.models.catalog import RegisteredScenario + mock_client = _mock_api_client() mock_client.get_scenario_async.return_value = None - mock_client.list_scenarios_async.return_value = { - "items": [{"scenario_name": "alt_a"}, {"scenario_name": "alt_b"}], - "pagination": {}, - } + mock_client.list_scenarios_async.return_value = [ + RegisteredScenario( + scenario_name="alt_a", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + ), + RegisteredScenario( + scenario_name="alt_b", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + ), + ] mock_client_class.return_value = mock_client result = pyrit_scan.main(["nonexistent", "--target", "t"]) @@ -780,17 +849,75 @@ class TestScenarioParamFlow: @staticmethod def _build_mock_client(supported_params=None, status="COMPLETED"): + from datetime import datetime, timezone from unittest.mock import AsyncMock + from pyrit.models import ScenarioRunState + from pyrit.models.catalog import ( + RegisteredScenario, + ScenarioParameterSummary, + ScenarioRunSummary, + ) + + now = datetime(2025, 1, 1, tzinfo=timezone.utc) + typed_params: list[ScenarioParameterSummary] = [] + for p in supported_params or []: + if isinstance(p, ScenarioParameterSummary): + typed_params.append(p) + else: + typed_params.append( + ScenarioParameterSummary( + name=p["name"], + description=p.get("description", ""), + default=p.get("default"), + param_type=p.get("param_type", "str"), + choices=p.get("choices"), + is_list=p.get("is_list", False), + ) + ) + client = AsyncMock() - client.list_scenarios_async.return_value = {"items": [{"scenario_name": "foo"}]} - client.get_scenario_async.return_value = { - "scenario_name": "foo", - "supported_parameters": supported_params or [], - } - client.start_scenario_run_async.return_value = {"scenario_result_id": "rid", "status": "CREATED"} - client.get_scenario_run_async.return_value = {"scenario_result_id": "rid", "status": status} - client.get_scenario_run_results_async.return_value = {"items": []} + client.list_scenarios_async.return_value = [ + RegisteredScenario( + scenario_name="foo", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + ) + ] + client.get_scenario_async.return_value = RegisteredScenario( + scenario_name="foo", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + supported_parameters=typed_params, + ) + client.start_scenario_run_async.return_value = ScenarioRunSummary( + scenario_result_id="rid", + scenario_name="foo", + scenario_version=0, + status=ScenarioRunState.CREATED, + created_at=now, + updated_at=now, + ) + client.get_scenario_run_async.return_value = ScenarioRunSummary( + scenario_result_id="rid", + scenario_name="foo", + scenario_version=0, + status=ScenarioRunState(status), + created_at=now, + updated_at=now, + ) + # Default: get_scenario_run_results_async raises so the summary fallback path runs. + client.get_scenario_run_results_async.side_effect = RuntimeError("results not configured") client.close_async = AsyncMock() client.__aenter__ = AsyncMock(return_value=client) client.__aexit__ = AsyncMock(return_value=None) @@ -808,7 +935,7 @@ def test_scenario_declared_flag_is_forwarded(self, _mock_prog, _mock_print, mock assert result == 0 sent_request = client.start_scenario_run_async.call_args.kwargs["request"] - assert sent_request["scenario_params"] == {"max_turns": "7"} + assert sent_request.scenario_params == {"max_turns": "7"} @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) @patch("pyrit.cli.api_client.PyRITApiClient") @@ -835,7 +962,7 @@ def test_no_scenario_params_passes_through_cleanly(self, _mock_prog, _mock_print assert result == 0 sent_request = client.start_scenario_run_async.call_args.kwargs["request"] - assert "scenario_params" not in sent_request + assert sent_request.scenario_params is None def test_parse_args_tolerates_scenario_specific_flags(self): # Pass 1 must not error on scenario-declared flags (they're recognized in pass 2). diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 788ed71ad5..0cc7ef1524 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -15,18 +15,46 @@ @pytest.fixture() def mock_api_client(): - """Create a mock PyRITApiClient with default responses.""" + """Create a mock PyRITApiClient with default responses (typed wire-data).""" + from datetime import datetime, timezone + + from pyrit.models.catalog import RegisteredScenario + client = AsyncMock() client.health_check_async.return_value = True - client.list_scenarios_async.return_value = {"items": [], "pagination": {"total": 0}} - client.list_initializers_async.return_value = {"items": [], "pagination": {"total": 0}} - client.list_targets_async.return_value = {"items": [], "pagination": {"total": 0}} - client.list_scenario_runs_async.return_value = {"items": []} - # Default: scenario fetch returns no declared params (back-compat for older tests) - client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + client.list_scenarios_async.return_value = [] + client.list_initializers_async.return_value = [] + client.list_targets_async.return_value = [] + client.list_scenario_runs_async.return_value = [] + # Default: scenario fetch returns a typed RegisteredScenario with no declared params. + client.get_scenario_async.return_value = RegisteredScenario( + scenario_name="foo", + scenario_type="X", + description="", + default_strategy="", + aggregate_strategies=[], + all_strategies=[], + default_datasets=[], + max_dataset_size=None, + supported_parameters=[], + ) client.close_async = AsyncMock() client.__aenter__ = AsyncMock(return_value=client) client.__aexit__ = AsyncMock(return_value=None) + # Helpers for tests to override the default scenario metadata. + client._make_typed_scenario = lambda **kw: RegisteredScenario( + scenario_name=kw.get("scenario_name", "foo"), + scenario_type=kw.get("scenario_type", "X"), + description=kw.get("description", ""), + default_strategy=kw.get("default_strategy", ""), + aggregate_strategies=kw.get("aggregate_strategies", []), + all_strategies=kw.get("all_strategies", []), + default_datasets=kw.get("default_datasets", []), + max_dataset_size=kw.get("max_dataset_size", None), + supported_parameters=kw.get("supported_parameters", []), + ) + # Suppress unused-import warning for datetime/timezone helpers used by tests. + _ = (datetime, timezone) return client @@ -101,13 +129,13 @@ def test_do_run_empty_args(self, shell, capsys): def test_do_scenario_history_default_limit(self, shell): s, client = shell - client.list_scenario_runs_async.return_value = {"items": []} + client.list_scenario_runs_async.return_value = [] s.do_scenario_history("") client.list_scenario_runs_async.assert_awaited_once_with(limit=10) def test_do_scenario_history_accepts_numeric_limit(self, shell): s, client = shell - client.list_scenario_runs_async.return_value = {"items": []} + client.list_scenario_runs_async.return_value = [] s.do_scenario_history("3") client.list_scenario_runs_async.assert_awaited_once_with(limit=3) @@ -378,8 +406,37 @@ def test_generic_error(self, shell, tmp_path, capsys): class TestDoRun: - def _run_payload(self, status="COMPLETED"): - return {"scenario_result_id": "rid-1", "status": status} + @staticmethod + def _run_payload(status="COMPLETED"): + """Build a typed ScenarioRunSummary for use as a mock return value.""" + from datetime import datetime, timezone + + from pyrit.models import ScenarioRunState + from pyrit.models.catalog import ScenarioRunSummary + + now = datetime(2025, 1, 1, tzinfo=timezone.utc) + return ScenarioRunSummary( + scenario_result_id="rid-1", + scenario_name="foo", + scenario_version=0, + status=ScenarioRunState(status), + created_at=now, + updated_at=now, + ) + + @staticmethod + def _empty_scenario_result(): + """Build a minimal ScenarioResult for use as get_scenario_run_results_async return.""" + from pyrit.models import ScenarioRunState + from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + + return ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="foo"), + objective_target_identifier=None, + objective_scorer_identifier=None, + attack_results={}, + scenario_run_state=ScenarioRunState.COMPLETED, + ) def test_run_invalid_arguments(self, shell, capsys): s, _ = shell @@ -401,7 +458,7 @@ def test_run_completed_path_with_results(self, shell, capsys): s, client = shell client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) client.get_scenario_run_async = AsyncMock(return_value=self._run_payload("COMPLETED")) - client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + client.get_scenario_run_results_async = AsyncMock(return_value=self._empty_scenario_result()) with ( patch( "pyrit.cli._cli_args.parse_run_arguments", @@ -423,15 +480,15 @@ def test_run_completed_path_with_results(self, shell, capsys): patch("time.sleep"), ): s.do_run("foo --target t") - kwargs = client.start_scenario_run_async.call_args.kwargs["request"] - assert kwargs["initializers"] == ["a", "b"] - assert kwargs["initializer_args"] == {"b": {"x": 1}} - assert kwargs["strategies"] == ["s1"] - assert kwargs["max_concurrency"] == 2 - assert kwargs["max_retries"] == 3 - assert kwargs["labels"] == {"k": "v"} - assert kwargs["dataset_names"] == ["d1"] - assert kwargs["max_dataset_size"] == 5 + sent = client.start_scenario_run_async.call_args.kwargs["request"] + assert sent.initializers == ["a", "b"] + assert sent.initializer_args == {"b": {"x": 1}} + assert sent.strategies == ["s1"] + assert sent.max_concurrency == 2 + assert sent.max_retries == 3 + assert sent.labels == {"k": "v"} + assert sent.dataset_names == ["d1"] + assert sent.max_dataset_size == 5 def test_run_failed_status_calls_summary(self, shell): s, client = shell @@ -531,8 +588,18 @@ def test_scenario_history_error(self, shell, capsys): class TestPrintScenarioAndHelp: def test_print_scenario_success(self, shell): - s, client = shell - client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + from pyrit.models import ScenarioRunState + from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + + s, client = shell + empty_result = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="foo"), + objective_target_identifier=None, + objective_scorer_identifier=None, + attack_results={}, + scenario_run_state=ScenarioRunState.COMPLETED, + ) + client.get_scenario_run_results_async = AsyncMock(return_value=empty_result) with patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) as mock_print: s.do_print_scenario("rid-1") mock_print.assert_awaited_once() @@ -683,14 +750,15 @@ class TestShellScenarioParamFlow: """Regression tests: shell.do_run must forward scenario-declared parameters.""" def test_run_passes_scenario_declared_params(self, shell): + from pyrit.models.catalog import ScenarioParameterSummary + s, client = shell - client.get_scenario_async.return_value = { - "scenario_name": "foo", - "supported_parameters": [{"name": "max_turns", "description": "..."}], - } - client.start_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "CREATED"}) - client.get_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "COMPLETED"}) - client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + client.get_scenario_async.return_value = client._make_typed_scenario( + supported_parameters=[ScenarioParameterSummary(name="max_turns", description="...", param_type="str")], + ) + client.start_scenario_run_async = AsyncMock(return_value=TestDoRun._run_payload("CREATED")) + client.get_scenario_run_async = AsyncMock(return_value=TestDoRun._run_payload("COMPLETED")) + client.get_scenario_run_results_async = AsyncMock(return_value=TestDoRun._empty_scenario_result()) with ( patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), @@ -700,7 +768,7 @@ def test_run_passes_scenario_declared_params(self, shell): s.do_run("foo --target t --max-turns 7") sent_request = client.start_scenario_run_async.call_args.kwargs["request"] - assert sent_request["scenario_params"] == {"max_turns": "7"} + assert sent_request.scenario_params == {"max_turns": "7"} def test_run_metadata_fetch_failure_aborts(self, shell, capsys): s, client = shell @@ -715,11 +783,12 @@ def test_run_unknown_scenario_aborts(self, shell, capsys): assert "not found on server" in capsys.readouterr().out def test_run_unknown_flag_for_scenario_with_declared_params_errors(self, shell, capsys): + from pyrit.models.catalog import ScenarioParameterSummary + s, client = shell - client.get_scenario_async.return_value = { - "scenario_name": "foo", - "supported_parameters": [{"name": "max_turns", "description": "..."}], - } + client.get_scenario_async.return_value = client._make_typed_scenario( + supported_parameters=[ScenarioParameterSummary(name="max_turns", description="...", param_type="str")], + ) s.do_run("foo --target t --not-a-real-flag x") captured = capsys.readouterr().out assert "Unknown argument" in captured or "Error" in captured @@ -727,7 +796,7 @@ def test_run_unknown_flag_for_scenario_with_declared_params_errors(self, shell, def test_run_fat_fingered_flag_with_no_scenario_params_errors(self, shell, capsys): """Even when the scenario declares no params, unknown flags must error (no silent no-op).""" s, client = shell - client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + client.get_scenario_async.return_value = client._make_typed_scenario(supported_parameters=[]) s.do_run("foo --target t --initialization-scripts /nope.py") captured = capsys.readouterr().out assert "Unknown argument: --initialization-scripts" in captured @@ -736,7 +805,7 @@ def test_run_fat_fingered_flag_with_no_scenario_params_errors(self, shell, capsy def test_run_fat_fingered_log_level_flag_errors(self, shell, capsys): """--log-level was a stale shell-only flag; passing it must now error.""" s, client = shell - client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + client.get_scenario_async.return_value = client._make_typed_scenario(supported_parameters=[]) s.do_run("foo --target t --log-level DEBUG") captured = capsys.readouterr().out assert "Unknown argument: --log-level" in captured @@ -747,16 +816,17 @@ class TestScenarioParamCoercionInShell: """Shell-side regression tests for typed scenario params from the catalog.""" def test_shell_list_param_collects_multiple_values(self, shell): + from pyrit.models.catalog import ScenarioParameterSummary + s, client = shell - client.get_scenario_async.return_value = { - "scenario_name": "foo", - "supported_parameters": [ - {"name": "items", "description": "list field", "param_type": "list[str]", "is_list": True} + client.get_scenario_async.return_value = client._make_typed_scenario( + supported_parameters=[ + ScenarioParameterSummary(name="items", description="list field", param_type="list[str]", is_list=True) ], - } - client.start_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "CREATED"}) - client.get_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "COMPLETED"}) - client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + ) + client.start_scenario_run_async = AsyncMock(return_value=TestDoRun._run_payload("CREATED")) + client.get_scenario_run_async = AsyncMock(return_value=TestDoRun._run_payload("COMPLETED")) + client.get_scenario_run_results_async = AsyncMock(return_value=TestDoRun._empty_scenario_result()) with ( patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), @@ -766,16 +836,17 @@ def test_shell_list_param_collects_multiple_values(self, shell): s.do_run("foo --target t --items a b c") sent = client.start_scenario_run_async.call_args.kwargs["request"] - assert sent["scenario_params"] == {"items": ["a", "b", "c"]} + assert sent.scenario_params == {"items": ["a", "b", "c"]} def test_shell_choices_rejected_before_request(self, shell, capsys): + from pyrit.models.catalog import ScenarioParameterSummary + s, client = shell - client.get_scenario_async.return_value = { - "scenario_name": "foo", - "supported_parameters": [ - {"name": "mode", "description": "...", "param_type": "str", "choices": ["fast", "slow"]} + client.get_scenario_async.return_value = client._make_typed_scenario( + supported_parameters=[ + ScenarioParameterSummary(name="mode", description="...", param_type="str", choices=["fast", "slow"]) ], - } + ) s.do_run("foo --target t --mode warp") out = capsys.readouterr().out # Parameter.coerce_value raises ValueError on out-of-choice values;