diff --git a/doc/gui/0_gui.md b/doc/gui/0_gui.md index e93cbb83b8..e203557859 100644 --- a/doc/gui/0_gui.md +++ b/doc/gui/0_gui.md @@ -136,7 +136,21 @@ The Configuration view manages the targets available for attacks. #### Target Table -Lists all registered targets with their type, endpoint, and model name. Click "Set Active" to select a target for use in the Chat view. The active target is highlighted with an "Active" badge. +Lists all registered targets with their type, endpoint, and model name. Click "Set Active" to select a target for use in the Chat view. The active target is highlighted with an "Active" badge. The **Validate** column (with a beaker icon button) lets you probe a target's live capabilities and see a declared-vs-observed diff (see [Validating Targets](#validating-targets) below). + +#### Validating Targets + +The **Validate** column in the target table has a beaker icon button on every top-level row that runs PyRIT's `discover_target_capabilities_async` engine against the selected target and opens a modal showing declared-vs-observed capability flags and input modalities. The button is placed next to the capability columns (Inputs, Outputs, Multi-turn, …) so it sits with the data it inspects. Use this when you want to confirm that a target actually accepts the request shapes its class declares (for example, when an Azure OpenAI gateway strips a feature, or when a multimodal class is pointed at a text-only deployment) before launching a long attack run. + +The dialog: + +- Sends real requests to the target — this may incur cost and produce side effects (logs, billing, content-policy hits). Test prompts are written to memory tagged `capability_probe`. +- Caps per-probe timeout at 5 seconds for GUI responsiveness. +- Reports output modalities as declared (those are not actively probed) and renders an amber em-dash for them. +- Reports declared input-modality combinations the engine has no packaged test asset for (e.g., `function_call`, `tool_call`, `reasoning`, `url`) in a separate "Not probed (no asset)" row rather than as false red mismatches. +- Should NOT be run while an attack or scenario is actively using the same target — validation temporarily changes the target's runtime configuration during probing. + +Only top-level registered targets have a Validate button; inner targets of composite wrappers (e.g., `RoundRobinTarget` children) are reachable only through the wrapper. #### Creating Targets diff --git a/frontend/src/components/Config/TargetTable.styles.ts b/frontend/src/components/Config/TargetTable.styles.ts index 23ce7d25c6..c17fc1a258 100644 --- a/frontend/src/components/Config/TargetTable.styles.ts +++ b/frontend/src/components/Config/TargetTable.styles.ts @@ -38,6 +38,10 @@ export const useTargetTableStyles = makeStyles({ width: '160px', textAlign: 'center', }, + validateCell: { + width: '90px', + textAlign: 'center', + }, modalityRow: { display: 'inline-flex', alignItems: 'center', diff --git a/frontend/src/components/Config/TargetTable.test.tsx b/frontend/src/components/Config/TargetTable.test.tsx index 17408d6d3c..7f7574a59b 100644 --- a/frontend/src/components/Config/TargetTable.test.tsx +++ b/frontend/src/components/Config/TargetTable.test.tsx @@ -1,12 +1,21 @@ -import { render, screen, fireEvent } from '@testing-library/react' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' import { FluentProvider, webLightTheme } from '@fluentui/react-components' import TargetTable from './TargetTable' import type { TargetInstance } from '../../types' +import { targetsApi } from '@/services/api' jest.mock('./TargetTable.styles', () => ({ useTargetTableStyles: () => new Proxy({}, { get: () => '' }), })) +jest.mock('@/services/api', () => ({ + targetsApi: { + validateCapabilities: jest.fn(), + }, +})) + +const mockedApi = targetsApi as jest.Mocked + const TestWrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => ( {children} ) @@ -397,4 +406,97 @@ describe('TargetTable', () => { expect(screen.queryByLabelText('Expand inner targets')).not.toBeInTheDocument() }) + + // --- F5: Validate button wiring --- + + it('renders a Validate button on every top-level row', () => { + render( + + + , + ) + const validateButtons = screen.getAllByRole('button', { name: /^Validate capabilities for / }) + // 3 sample targets, 1 button each (no active target → no extra active-row button) + expect(validateButtons).toHaveLength(3) + }) + + it('also renders a Validate button on the active-target summary row', () => { + render( + + + , + ) + const validateButtons = screen.getAllByRole('button', { name: /^Validate capabilities for / }) + // 3 list rows + 1 active-row summary = 4 + expect(validateButtons).toHaveLength(4) + }) + + it('does NOT render Validate buttons on inner-target rows (composite expansion)', () => { + const rrTarget: TargetInstance = { + target_registry_name: 'rr_gpt4o', + target_type: 'RoundRobinTarget', + model_name: 'gpt-4o', + target_specific_params: { weights: [1, 1] }, + inner_targets: [ + { + target_registry_name: 'inner_a', + target_type: 'OpenAIChatTarget', + endpoint: 'https://a.openai.azure.com', + model_name: 'gpt-4o', + }, + { + target_registry_name: 'inner_b', + target_type: 'OpenAIChatTarget', + endpoint: 'https://b.openai.azure.com', + model_name: 'gpt-4o', + }, + ], + } + render( + + + , + ) + // Before expanding: 1 top-level row → 1 Validate button + expect(screen.getAllByRole('button', { name: /^Validate capabilities for / })).toHaveLength(1) + // Expand + fireEvent.click(screen.getByLabelText('Expand inner targets')) + expect(screen.getByText('https://a.openai.azure.com')).toBeInTheDocument() + // After expanding: still only 1 Validate button (inner rows don't get one) + expect(screen.getAllByRole('button', { name: /^Validate capabilities for / })).toHaveLength(1) + }) + + it('opens the validation dialog when a Validate button is clicked', async () => { + // Pending promise so the dialog stays in the loading state we can detect. + mockedApi.validateCapabilities.mockReturnValue(new Promise(() => {})) + render( + + + , + ) + const validateButtons = screen.getAllByRole('button', { name: /^Validate capabilities for / }) + fireEvent.click(validateButtons[0]) + await waitFor(() => { + expect(mockedApi.validateCapabilities).toHaveBeenCalledWith('openai_chat_gpt4') + }) + expect(screen.getByText(/Validate capabilities: openai_chat_gpt4/i)).toBeInTheDocument() + }) + + it('disables the Validate button for the row whose dialog is currently open', async () => { + mockedApi.validateCapabilities.mockReturnValue(new Promise(() => {})) + render( + + + , + ) + const validateButtons = screen.getAllByRole('button', { name: /^Validate capabilities for / }) + fireEvent.click(validateButtons[0]) + await waitFor(() => { + // The first row's Validate button is now disabled. + const stillButtons = screen.getAllByRole('button', { name: /^Validate capabilities for / }) + expect(stillButtons[0]).toBeDisabled() + // The other rows' buttons remain enabled. + expect(stillButtons[1]).not.toBeDisabled() + }) + }) }) diff --git a/frontend/src/components/Config/TargetTable.tsx b/frontend/src/components/Config/TargetTable.tsx index 79d1df3777..0a7e0b6fa1 100644 --- a/frontend/src/components/Config/TargetTable.tsx +++ b/frontend/src/components/Config/TargetTable.tsx @@ -28,9 +28,11 @@ import { ArrowHookUpLeftRegular, ChevronRightRegular, ChevronDownRegular, + BeakerRegular, } from '@fluentui/react-icons' import type { TargetInstance } from '../../types' import { useTargetTableStyles } from './TargetTable.styles' +import ValidateCapabilitiesDialog from './ValidateCapabilitiesDialog' interface TargetTableProps { targets: TargetInstance[] @@ -72,6 +74,7 @@ const COLUMN_TOOLTIPS = { parameters: 'Target-specific configuration parameters (e.g., reasoning_effort, max_output_tokens)', inputs: 'Modalities the target accepts as input', outputs: 'Modalities the target can produce as output', + validate: 'Probe the target live and compare observed capabilities to the declared values shown in this row', } as const /** Composite icon: f(x) with a small return-arrow badge for function call outputs. */ @@ -244,6 +247,10 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } // We use a Set of target_registry_name strings — when a name is in the set, // that row's sub-rows are visible. const [expandedRows, setExpandedRows] = useState>(new Set()) + // The target whose Validate dialog is currently open, or null. + // Inner-target rows (composite expansion) do NOT get a Validate button — + // they aren't registered by name in the backend TargetRegistry. + const [validateTarget, setValidateTarget] = useState(null) const toggleExpanded = (registryName: string) => { setExpandedRows((prev) => { @@ -310,6 +317,18 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } + + + + + + + + ) +} diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index 3c04828cb0..320a9a2a97 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -20,6 +20,7 @@ import type { CreateConversationRequest, CreateConversationResponse, ChangeMainConversationResponse, + ValidateCapabilitiesResponse, } from '../types' const API_BASE_URL = import.meta.env.VITE_API_URL || '/api' @@ -162,6 +163,18 @@ export const targetsApi = { const response = await apiClient.post('/targets', request) return response.data }, + + validateCapabilities: async ( + targetRegistryName: string, + ): Promise => { + // POST is appropriate here: the call sends live requests to the target + // and writes probe rows to memory (side effects), even though the response + // shape is a read-only diff. + const response = await apiClient.post( + `/targets/${encodeURIComponent(targetRegistryName)}/validate`, + ) + return response.data + }, } export const convertersApi = { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1c6dcc283e..8d4a0af6ee 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -98,6 +98,34 @@ export interface CreateTargetRequest { auth_mode?: 'api_key' | 'entra' } +export interface ValidateCapabilitiesResponse { + target_registry_name: string + declared: TargetCapabilitiesInfo + observed: TargetCapabilitiesInfo + /** + * Sorted '+'-joined declared input-modality combinations that the engine + * could not probe because no packaged test asset exists (e.g., + * 'function_call', 'image_path+url'). Used by ValidateCapabilitiesDialog to + * render a single "Not probed (no asset)" row beneath the input-modalities + * row, distinguishing "not probed" from "probed and confirmed". + */ + non_probeable_input_modalities: string[] + /** + * Sorted list of declared input-modality types that appear ONLY in + * non-probeable combinations (never in any probeable combination). Used + * by ValidateCapabilitiesDialog to filter the input-modality cells without + * accidentally hiding types confirmed via a probeable singleton combo — + * e.g., for a target declaring both `{text}` and `{text, function_call}`, + * this list contains only `function_call`, leaving `text` visible. + */ + non_probeable_only_types: string[] + /** + * Operational notes for the user (live-call cost, memory tagging, output + * modalities not probed, semantic-enforcement caveat, validate-vs-active-attack). + */ + warnings: string[] +} + // --- Converters --- export interface ConverterInstance { diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py index 310b04e916..a6d169fc9f 100644 --- a/pyrit/backend/mappers/__init__.py +++ b/pyrit/backend/mappers/__init__.py @@ -20,6 +20,7 @@ converter_object_to_instance, ) from pyrit.backend.mappers.target_mappers import ( + target_capabilities_to_info, target_object_to_instance, ) @@ -31,5 +32,6 @@ "pyrit_scores_to_dto", "request_piece_to_pyrit_message_piece", "request_to_pyrit_message", + "target_capabilities_to_info", "target_object_to_instance", ] diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index b7f715aca0..6f5ced845e 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -15,7 +15,7 @@ _CAPABILITY_PARAM_NAMES = frozenset(cap.value for cap in CapabilityName) -def _target_capabilities_to_info(capabilities: TargetCapabilities) -> TargetCapabilitiesInfo: +def target_capabilities_to_info(capabilities: TargetCapabilities) -> TargetCapabilitiesInfo: """ Build a TargetCapabilitiesInfo DTO from a domain TargetCapabilities object. @@ -102,7 +102,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge temperature=params.get("temperature"), top_p=params.get("top_p"), max_requests_per_minute=params.get("max_requests_per_minute"), - capabilities=_target_capabilities_to_info(target_obj.capabilities), + capabilities=target_capabilities_to_info(target_obj.capabilities), target_specific_params=combined_specific, inner_targets=inner_targets, identifier_hash=identifier.hash, diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 388076fcd5..08792024be 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -63,6 +63,7 @@ TargetCapabilitiesInfo, TargetInstance, TargetListResponse, + ValidateCapabilitiesResponse, ) __all__ = [ @@ -117,4 +118,5 @@ "TargetCapabilitiesInfo", "TargetInstance", "TargetListResponse", + "ValidateCapabilitiesResponse", ] diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index 9e5ce0c56d..e59bd90b71 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -77,6 +77,61 @@ class TargetListResponse(BaseModel): pagination: PaginationInfo = Field(..., description="Pagination metadata") +class ValidateCapabilitiesResponse(BaseModel): + """ + Response from validating a target's declared capabilities against observed behavior. + + Surfaces what the target class declares versus what live probing observed, + so users can spot drift caused by gateways stripping features, model + deployments lacking capabilities, or misconfiguration. + """ + + target_registry_name: str = Field(..., description="Target registry key the validation ran against") + declared: TargetCapabilitiesInfo = Field(..., description="Capabilities as declared by the target class") + observed: TargetCapabilitiesInfo = Field(..., description="Capabilities as observed by live probing") + # Drives the frontend "Not probed (no asset)" row beneath the input-modalities + # row. Without this field, the engine's `queried | (declared - test_modalities)` + # math at discover_target_capabilities.py:778 ORs non-probeable combinations + # back into observed, making observed == declared, and the frontend has no way + # to distinguish "genuinely confirmed" from "not probed". + non_probeable_input_modalities: list[str] = Field( + default_factory=list, + description=( + "Sorted list of declared input-modality combinations that could NOT be probed " + "because the engine has no packaged test asset for the contained types. Each " + "entry is a '+'-joined sorted combination (e.g., 'function_call' or 'image_path+url'). " + "The frontend renders the union of these as a single 'Not probed (no asset)' row " + "beneath the input-modalities row." + ), + ) + # Distinct from ``non_probeable_input_modalities`` (which carries the + # combo display strings). When a target declares both a probeable combo + # like ``{text}`` and a non-probeable mixed combo like ``{text, + # function_call}``, splitting the combo string on '+' and stripping every + # piece from the input-modality cells would incorrectly hide ``text`` — + # which *was* probed and confirmed via the singleton combo. This field + # lists only the types that never appear in any probeable combo, so the + # frontend can safely filter cells without dropping confirmed modalities. + non_probeable_only_types: list[str] = Field( + default_factory=list, + description=( + "Sorted list of declared input modality types that appear ONLY in non-probeable " + "combinations (never in any probeable combination). The frontend uses this set to " + "hide truly unprobed types from the input-modality cells while leaving types that " + "were confirmed via a probeable singleton combo visible. Disjoint from the types " + "implicit in ``observed.supported_input_modalities`` that came from a probeable probe." + ), + ) + warnings: list[str] = Field( + default_factory=list, + description=( + "Operational notes for the user (e.g., 'this validation wrote test prompts to memory', " + "'output modalities are not probed and fall through to declared values', " + "'do not validate while an attack is actively running against this target')." + ), + ) + + class CreateTargetRequest(BaseModel): """Request to create a new target instance.""" diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index bea53ddef2..23760f4639 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -15,6 +15,7 @@ CreateTargetRequest, TargetInstance, TargetListResponse, + ValidateCapabilitiesResponse, ) from pyrit.backend.services.target_service import get_target_service @@ -104,3 +105,50 @@ async def get_target(target_registry_name: str) -> TargetInstance: # pyrit-asyn ) return target + + +@router.post( + "/{target_registry_name}/validate", + response_model=ValidateCapabilitiesResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Target not found"}, + 500: {"model": ProblemDetail, "description": "Validation failed"}, + }, +) +async def validate_target_capabilities( # pyrit-async-suffix-exempt + target_registry_name: str, +) -> ValidateCapabilitiesResponse: + """ + Validate a target by probing its live capabilities against declarations. + + The probe sends a small set of test requests to the target and reports + the declared vs observed capability flags and input modalities. Output + modalities are reported as declared (not actively probed). Test prompts + are written to memory. + + Returns: + ValidateCapabilitiesResponse: Declared and observed capabilities, plus + a list of declared input-modality combinations that could not be + probed because no test asset is packaged for them, plus operational + warnings (live-call cost, memory tagging, semantic-enforcement caveat, + validate-vs-active-attack caveat). + """ + service = get_target_service() + + try: + result = await service.validate_target_capabilities_async( + target_registry_name=target_registry_name, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to validate target: {str(e)}", + ) from e + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Target '{target_registry_name}' not found", + ) + + return result diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 6663dfa57b..8a1eb2885d 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -12,6 +12,7 @@ - Retrieved from registry (pre-registered at startup or created earlier) """ +import asyncio import logging import os from functools import lru_cache @@ -20,14 +21,16 @@ from pyrit import prompt_target from pyrit.auth import get_azure_async_token_provider, get_azure_openai_auth -from pyrit.backend.mappers.target_mappers import target_object_to_instance +from pyrit.backend.mappers.target_mappers import target_capabilities_to_info, target_object_to_instance from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.targets import ( CreateTargetRequest, TargetInstance, TargetListResponse, + ValidateCapabilitiesResponse, ) -from pyrit.prompt_target import PromptTarget +from pyrit.models import PromptDataType +from pyrit.prompt_target import PromptTarget, discover_target_capabilities_async from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.openai.openai_target import OpenAITarget from pyrit.prompt_target.round_robin_target import RoundRobinTarget @@ -35,6 +38,19 @@ logger = logging.getLogger(__name__) +# Module-level allowlist of input modalities that the discovery engine actually +# has probe assets for. See `discover_target_capabilities.py:DEFAULT_TEST_ASSETS` +# (image_path, audio_path) and `_create_test_message` (text is synthetic). +# +# Any combination containing a modality outside this set will raise ValueError +# inside the engine, be silently skipped, and — when test_modalities is set +# explicitly — be DROPPED from the resolved input_modalities +# (`queried | (declared - frozenset(test_modalities))`). Filtering here prevents +# false red mismatches against real targets like OpenAIResponseTarget +# (declares function_call, tool_call, reasoning) and AzureBlobStorageTarget +# (declares url). +_PROBEABLE_INPUT_MODALITIES: frozenset[PromptDataType] = frozenset({"text", "image_path", "audio_path"}) + # Recognised Azure OpenAI / AI Foundry hostname suffixes. Used for strict # endpoint validation when Entra ID auth is requested, so a bearer token is # only ever issued for a known Microsoft-operated endpoint. @@ -139,9 +155,53 @@ class TargetService: # Scope for Azure Machine Learning managed online endpoints. _AZURE_ML_SCOPE: ClassVar[str] = "https://ml.azure.com/.default" + # Per-probe timeout for the GUI validation flow. The engine default is + # 30 s, which compounded across 5+ probes can exceed 2 min; 5 s keeps + # the GUI snappy while still catching real rejections. + _GUI_VALIDATE_TIMEOUT_S: ClassVar[float] = 15.0 + def __init__(self) -> None: """Initialize the target service.""" self._registry = TargetRegistry.get_registry_singleton() + # Per-target asyncio locks for capability validation. The discovery + # engine mutates `target._configuration` in place + # (discover_target_capabilities.py:_permissive_configuration) and the + # registry returns a singleton instance, so two concurrent validations + # on the same target can race on the restore. The lock dict is an + # INSTANCE attribute (not ClassVar): an asyncio.Lock lazy-binds to + # the running event loop on first await, and pytest gives each test a + # fresh event loop (pyproject.toml: asyncio_default_fixture_loop_scope + # = "function"). A ClassVar dict would leak locks from one test's + # loop into the next and raise RuntimeError. Instance dict = fresh + # per TargetService() = matches existing per-test pattern. + # Lock-map cardinality is bounded by registry size (one entry per + # registered target), not by call volume, so no eviction is needed + # for typical PyRIT workloads. + self._validate_locks: dict[str, asyncio.Lock] = {} + + def _get_validate_lock(self, *, target_registry_name: str) -> asyncio.Lock: + """ + Get-or-create the per-target validation lock. + + Kept synchronous on purpose: there is no ``await`` between the dict + ``get`` and the assignment, so two coroutines cannot interleave + between them and no extra guard lock is needed. Staying sync also + sidesteps the ``check-async-suffix`` hook (no ``_async`` suffix + needed for non-async methods). The returned ``asyncio.Lock`` binds + lazily to the running event loop on the caller's first + ``await lock.acquire()``. + + Args: + target_registry_name: The registry key of the target whose lock to fetch. + + Returns: + The per-target ``asyncio.Lock`` (created on first access). + """ + lock = self._validate_locks.get(target_registry_name) + if lock is None: + lock = asyncio.Lock() + self._validate_locks[target_registry_name] = lock + return lock def _get_target_class(self, *, target_type: str) -> type: """ @@ -241,6 +301,129 @@ def get_target_object(self, *, target_registry_name: str) -> Any | None: """ return self._registry.get_instance_by_name(target_registry_name) + async def validate_target_capabilities_async( + self, + *, + target_registry_name: str, + per_probe_timeout_s: float | None = None, + ) -> ValidateCapabilitiesResponse | None: + """ + Probe a target's live capabilities and return both declared and observed views. + + The probe writes test prompts to memory (existing behavior of the + discovery engine). Output modalities are not probed and fall through + to declared values. Probeable input modalities (text, image_path, + audio_path) listed in the target's declared capabilities are probed + explicitly so that rejections surface as drift; non-probeable + declared modalities (function_call, tool_call, reasoning, url, + video_path, binary_path, etc.) are reported as declared without + being probed and listed in ``non_probeable_input_modalities`` so + the frontend can render a single "Not probed (no asset)" row. + + Args: + target_registry_name: The registry key of the target to validate. + per_probe_timeout_s: Per-probe timeout in seconds. Defaults to + ``_GUI_VALIDATE_TIMEOUT_S`` (15.0) for interactive use. + + Returns: + ValidateCapabilitiesResponse, or None if the target is not in the registry. + """ + timeout_s = per_probe_timeout_s if per_probe_timeout_s is not None else self._GUI_VALIDATE_TIMEOUT_S + + target_obj = self.get_target_object(target_registry_name=target_registry_name) + if target_obj is None: + return None + + declared = target_capabilities_to_info(target_obj.capabilities) + + # CRITICAL: pass only the *probeable* declared modality combinations as + # ``test_modalities``. Without this filter, a combination like + # ``frozenset(["function_call"])`` raises ValueError inside + # ``_create_test_message`` (engine: discover_target_capabilities.py), + # the combo is silently skipped, and the result line + # ``queried | (declared - frozenset(test_modalities))`` drops it — + # producing a false red mismatch in the UI. The non-probeable combos + # are surfaced via ``non_probeable_input_modalities`` so the frontend + # can render them as "Not probed (no asset)" rather than mismatched. + declared_combinations: set[frozenset[PromptDataType]] = set(target_obj.capabilities.input_modalities) + probeable_combinations: set[frozenset[PromptDataType]] = { + combo for combo in declared_combinations if combo <= _PROBEABLE_INPUT_MODALITIES + } + non_probeable: set[frozenset[PromptDataType]] = declared_combinations - probeable_combinations + + # Per-target lock guards against the ``target._configuration`` race + # documented above. Helper is sync (see its docstring). Pass + # ``test_modalities=probeable_combinations`` even when empty: the engine + # short-circuits cleanly on empty set (logs "nothing to probe", returns + # empty) before entering ``_permissive_configuration``, avoiding an + # unnecessary configuration mutate+restore round-trip. Passing ``None`` + # would default the engine to all declared modalities and trigger + # ValueError-and-skip-log noise on every non-probeable combo. + lock = self._get_validate_lock(target_registry_name=target_registry_name) + async with lock: + observed_domain = await discover_target_capabilities_async( + target=target_obj, + per_probe_timeout_s=timeout_s, + test_modalities=probeable_combinations, + apply=False, + # retries left at the engine default (1) so cold-start targets + # don't false-negative; worst-case wait per probe is ~10 s. + ) + observed = target_capabilities_to_info(observed_domain) + + warnings = [ + ( + "Validation sent live requests to the target; this may incur cost " + "and produce real side effects (logs, billing, content policy hits)." + ), + "Test prompts written to memory are tagged with `capability_probe`.", + "Output modalities are reported as declared (not actively probed).", + ( + "Capability probes confirm request acceptance, not semantic enforcement " + "(e.g., a target that accepts a JSON-schema request may not actually " + "enforce the schema). Image probes can also currently false-negative on " + "some targets due to known probe-asset and payload-format issues; re-run " + "or verify manually before relying on a red image result." + ), + # The per-target lock above serializes Validate-vs-Validate but NOT + # Validate-vs-attack: the engine briefly mutates target._configuration, + # so an active attack on this target during validation may briefly + # observe permissive probe config. Surface as a user-visible warning. + ( + "Do not run Validate while an attack or scenario is actively using " + "this target — validation temporarily changes target configuration " + "during probing." + ), + ] + + # Format non-probeable combinations as a stable, '+'-joined sorted list + # so the frontend gets a typed, explicit signal (no warning-string parsing). + non_probeable_combos_pretty: list[str] = sorted("+".join(sorted(combo)) for combo in non_probeable) + if non_probeable_combos_pretty: + warnings.append( + "Some declared input modalities are reported as declared/not-probed " + f"(no packaged probe asset): {', '.join(non_probeable_combos_pretty)}." + ) + + # Types that appear ONLY in non-probeable combos (never in a probeable + # one). The frontend uses this for cell-filtering: if a type also + # belongs to some probeable combo it WAS confirmed, so the cell should + # still show it. Splitting the combo strings on '+' and using the union + # would incorrectly hide ``text`` for a target declaring both + # ``{text}`` and ``{text, function_call}``. + probeable_types: set[PromptDataType] = set().union(*probeable_combinations) if probeable_combinations else set() + non_probeable_types: set[PromptDataType] = set().union(*non_probeable) if non_probeable else set() + non_probeable_only_types: list[str] = sorted(non_probeable_types - probeable_types) + + return ValidateCapabilitiesResponse( + target_registry_name=target_registry_name, + declared=declared, + observed=observed, + non_probeable_input_modalities=non_probeable_combos_pretty, + non_probeable_only_types=non_probeable_only_types, + warnings=warnings, + ) + async def create_target_async(self, *, request: CreateTargetRequest) -> TargetInstance: """ Create a new target instance from API request. diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 59bf407382..f78914a8c1 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -38,6 +38,7 @@ TargetCapabilitiesInfo, TargetInstance, TargetListResponse, + ValidateCapabilitiesResponse, ) from pyrit.backend.routes.labels import get_label_options @@ -954,6 +955,66 @@ def test_get_target_includes_target_specific_params(self, client: TestClient) -> assert data["target_specific_params"]["presence_penalty"] == 0.3 assert data["target_specific_params"]["seed"] == 42 + def test_validate_target_returns_200_with_declared_and_observed(self, client: TestClient) -> None: + """Happy path: validate route returns 200 with full ValidateCapabilitiesResponse shape.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.validate_target_capabilities_async = AsyncMock( + return_value=ValidateCapabilitiesResponse( + target_registry_name="target-1", + declared=TargetCapabilitiesInfo( + supports_json_schema=True, + supported_input_modalities=["image_path", "text"], + ), + observed=TargetCapabilitiesInfo( + supports_json_schema=False, + supported_input_modalities=["text"], + ), + non_probeable_input_modalities=["function_call"], + non_probeable_only_types=["function_call"], + warnings=["Validation sent live requests to the target; ..."], + ) + ) + mock_get_service.return_value = mock_service + + response = client.post("/api/targets/target-1/validate") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["target_registry_name"] == "target-1" + assert data["declared"]["supports_json_schema"] is True + assert data["observed"]["supports_json_schema"] is False + assert data["non_probeable_input_modalities"] == ["function_call"] + assert data["non_probeable_only_types"] == ["function_call"] + assert isinstance(data["warnings"], list) and data["warnings"] + + def test_validate_target_returns_404_when_target_missing(self, client: TestClient) -> None: + """Unknown target → service returns None → 404 with FastAPI's default {'detail': ...} body.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.validate_target_capabilities_async = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.post("/api/targets/missing/validate") + + assert response.status_code == status.HTTP_404_NOT_FOUND + # Backend has no HTTPException → ProblemDetail handler; default shape applies. + assert response.json() == {"detail": "Target 'missing' not found"} + + def test_validate_target_returns_500_when_probe_fails(self, client: TestClient) -> None: + """Engine raises → 500 with default {'detail': 'Failed to validate target: ...'} body.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.validate_target_capabilities_async = AsyncMock(side_effect=RuntimeError("network blew up")) + mock_get_service.return_value = mock_service + + response = client.post("/api/targets/target-1/validate") + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + body = response.json() + assert "Failed to validate target" in body["detail"] + assert "network blew up" in body["detail"] + # ============================================================================ # Converter Routes Tests diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index b86467eb88..879186b093 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -823,3 +823,463 @@ def test_target_eval_param_fallbacks_match_frontend(self) -> None: f"Update effectiveUnderlyingModel() in CreateTargetDialog.tsx to match, " f"then update this test's expected dict." ) + + +# ============================================================================ +# Capability Validation Tests +# ============================================================================ + + +def _fake_target_with_capabilities( + *, + input_modalities: frozenset[frozenset[str]] | None = None, + supports_json_schema: bool = True, +) -> MagicMock: + """ + Build a MagicMock target whose ``capabilities`` attribute is a real + ``TargetCapabilities`` object. The discovery engine is mocked separately, + so we don't need a real PromptTarget subclass — just the ``.capabilities`` + attribute that ``validate_target_capabilities_async`` reads. + """ + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_json_schema=supports_json_schema, + supports_json_output=True, + supports_editable_history=True, + supports_system_prompt=True, + input_modalities=(input_modalities if input_modalities is not None else frozenset({frozenset(["text"])})), + ) + target = MagicMock() + target.capabilities = caps + return target + + +def _fake_observed_capabilities( + *, + declared, # TargetCapabilities + drop_input_modalities: set[frozenset[str]] | None = None, + flip_json_schema_to_false: bool = False, +): + """ + Build a fake "observed" TargetCapabilities for the mock engine to return. + + Mirrors what the real engine produces: starts from ``declared`` and + selectively drops/flips fields to simulate drift. + """ + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + + observed_input = declared.input_modalities + if drop_input_modalities: + observed_input = frozenset(c for c in observed_input if c not in drop_input_modalities) + return TargetCapabilities( + supports_multi_turn=declared.supports_multi_turn, + supports_multi_message_pieces=declared.supports_multi_message_pieces, + supports_json_schema=False if flip_json_schema_to_false else declared.supports_json_schema, + supports_json_output=declared.supports_json_output, + supports_editable_history=declared.supports_editable_history, + supports_system_prompt=declared.supports_system_prompt, + input_modalities=observed_input, + output_modalities=declared.output_modalities, + ) + + +class TestValidateTargetCapabilities: + """Tests for TargetService.validate_target_capabilities_async.""" + + async def test_returns_none_for_unknown_target(self) -> None: + """Unknown registry name returns None; engine is NOT called.""" + from unittest.mock import AsyncMock + + service = TargetService() + with patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe: + result = await service.validate_target_capabilities_async(target_registry_name="missing") + assert result is None + mock_probe.assert_not_called() + + async def test_returns_response_for_known_target(self) -> None: + """Happy path: declared + observed populated, warnings present, no non-probeable.""" + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities() + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + result = await service.validate_target_capabilities_async(target_registry_name="t1") + + assert result is not None + assert result.target_registry_name == "t1" + assert result.declared.supports_json_schema is True + assert result.observed.supports_json_schema is True + assert result.non_probeable_input_modalities == [] + # 5 base warnings, no 6th (no non-probeable) + assert len(result.warnings) == 5 + + async def test_passes_timeout_override(self) -> None: + """Caller-supplied per_probe_timeout_s reaches the discovery call.""" + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities() + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + await service.validate_target_capabilities_async(target_registry_name="t1", per_probe_timeout_s=10.0) + assert mock_probe.call_args.kwargs["per_probe_timeout_s"] == 10.0 + + async def test_uses_gui_default_timeout_when_not_overridden(self) -> None: + """When per_probe_timeout_s is None, the GUI default (5.0) is passed.""" + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities() + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + await service.validate_target_capabilities_async(target_registry_name="t1") + assert mock_probe.call_args.kwargs["per_probe_timeout_s"] == TargetService._GUI_VALIDATE_TIMEOUT_S + assert mock_probe.call_args.kwargs["per_probe_timeout_s"] == 15.0 + + async def test_passes_probeable_modalities_only(self) -> None: + """ + CRITICAL regression guard: only probeable modality combinations + reach the engine. Non-probeable combos appear in the response's + ``non_probeable_input_modalities`` list and in a warning. + """ + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities( + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["function_call"]), + frozenset(["url"]), + } + ) + ) + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + result = await service.validate_target_capabilities_async(target_registry_name="t1") + + # (a) only the two probeable combos reach the engine + passed = mock_probe.call_args.kwargs["test_modalities"] + assert passed == {frozenset(["text"]), frozenset(["text", "image_path"])} + + # (b) non-probeable types appear in the warnings list + assert result is not None + non_probed_warning = [w for w in result.warnings if "no packaged probe asset" in w] + assert len(non_probed_warning) == 1 + assert "function_call" in non_probed_warning[0] + assert "url" in non_probed_warning[0] + + # (c) typed field has the sorted, '+'-joined list + assert result.non_probeable_input_modalities == ["function_call", "url"] + + # (d) function_call and url appear ONLY in non-probeable combos + # (each in its own singleton, no probeable combo includes them). + assert result.non_probeable_only_types == ["function_call", "url"] + + async def test_non_probeable_only_types_excludes_types_confirmed_via_probeable_combo(self) -> None: + """ + Regression guard for the dialog cell-filter bug: when a target + declares both a probeable singleton like ``{text}`` AND a non-probeable + mixed combo like ``{text, function_call}``, ``text`` IS confirmed + (via the singleton) and must not appear in + ``non_probeable_only_types`` — otherwise the frontend would strip + ``text`` from the Input modalities cells and show ``— / —`` despite + it being probed and confirmed. + + ``non_probeable_input_modalities`` (the combo display list) still + contains the mixed combo so the "Not probed (no asset)" row can + surface it; the cell-filter logic uses the narrower + ``non_probeable_only_types`` set instead. + """ + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities( + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "function_call"]), + frozenset(["image_path"]), + frozenset(["image_path", "url"]), + } + ) + ) + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + result = await service.validate_target_capabilities_async(target_registry_name="t1") + + assert result is not None + # Combo display list keeps the mixed combos (used by the "Not probed" row). + assert result.non_probeable_input_modalities == ["function_call+text", "image_path+url"] + # Cell-filter list contains only types NOT confirmed by any probeable combo. + # `text` is confirmed by {text}; `image_path` is confirmed by {image_path}. + # Only `function_call` and `url` are exclusively non-probeable. + assert result.non_probeable_only_types == ["function_call", "url"] + + async def test_passes_empty_set_when_no_probeable_modalities(self) -> None: + """ + Declared modalities are all non-probeable. Method passes + ``test_modalities=set()`` (NOT None) so the engine short-circuits + cleanly without entering ``_permissive_configuration``. Warnings + still include the not-probed entry, and the typed field lists every + declared combo. + """ + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities( + input_modalities=frozenset( + { + frozenset(["function_call"]), + frozenset(["url"]), + frozenset(["video_path"]), + } + ) + ) + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + ) as mock_probe, + ): + mock_probe.return_value = observed + result = await service.validate_target_capabilities_async(target_registry_name="t1") + + passed = mock_probe.call_args.kwargs["test_modalities"] + assert passed == set() + assert isinstance(passed, set) + assert result is not None + assert result.non_probeable_input_modalities == ["function_call", "url", "video_path"] + # With no probeable combos, every declared type is exclusively non-probeable. + assert result.non_probeable_only_types == ["function_call", "url", "video_path"] + assert any("no packaged probe asset" in w for w in result.warnings) + + async def test_propagates_probe_exceptions(self) -> None: + """Engine raises → method raises. Lock is released even on exception.""" + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities() + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + side_effect=RuntimeError("engine boom"), + ), + ): + with pytest.raises(RuntimeError, match="engine boom"): + await service.validate_target_capabilities_async(target_registry_name="t1") + + # Lock must be released (the dict entry stays, but the lock isn't held). + lock = service._validate_locks["t1"] + assert not lock.locked(), "lock leaked after engine raised" + + async def test_serializes_concurrent_calls_on_same_target(self) -> None: + """ + Two concurrent calls on the same registry name within the same service + instance + same event loop serialize via the per-target lock. + """ + import asyncio + + service = TargetService() + fake_target = _fake_target_with_capabilities() + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + + first_running = asyncio.Event() + release_first = asyncio.Event() + order: list[str] = [] + + async def slow_first(**_kwargs): + order.append("first-enter") + first_running.set() + await release_first.wait() + order.append("first-exit") + return observed + + async def fast_second(**_kwargs): + order.append("second-enter") + order.append("second-exit") + return observed + + call_count = {"n": 0} + + async def dispatch(**kwargs): + call_count["n"] += 1 + return await (slow_first(**kwargs) if call_count["n"] == 1 else fast_second(**kwargs)) + + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new=dispatch, + ), + ): + task_first = asyncio.create_task(service.validate_target_capabilities_async(target_registry_name="t1")) + await first_running.wait() + task_second = asyncio.create_task(service.validate_target_capabilities_async(target_registry_name="t1")) + # Give scheduler a tick — second must NOT have started. + await asyncio.sleep(0.05) + assert "second-enter" not in order, f"second leaked through: {order}" + release_first.set() + await asyncio.gather(task_first, task_second) + + assert order == ["first-enter", "first-exit", "second-enter", "second-exit"] + + async def test_allows_concurrent_calls_on_different_targets(self) -> None: + """Two concurrent calls on different targets do NOT serialize.""" + import asyncio + + service = TargetService() + fake_a = _fake_target_with_capabilities() + fake_b = _fake_target_with_capabilities() + observed_a = _fake_observed_capabilities(declared=fake_a.capabilities) + observed_b = _fake_observed_capabilities(declared=fake_b.capabilities) + + a_running = asyncio.Event() + b_started = asyncio.Event() + + async def dispatch_a(**_kwargs): + a_running.set() + await b_started.wait() # must NOT block on B if locks are per-target + return observed_a + + async def dispatch_b(**_kwargs): + b_started.set() + return observed_b + + def get_target(*, target_registry_name: str): + return fake_a if target_registry_name == "a" else fake_b + + # Per-target dispatch via call_args inspection + async def probe(*, target, **kwargs): + if target is fake_a: + return await dispatch_a(**kwargs) + return await dispatch_b(**kwargs) + + with ( + patch.object(service, "get_target_object", side_effect=get_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new=probe, + ), + ): + task_a = asyncio.create_task(service.validate_target_capabilities_async(target_registry_name="a")) + await a_running.wait() + task_b = asyncio.create_task(service.validate_target_capabilities_async(target_registry_name="b")) + # If locks were shared, task_a would deadlock waiting on b_started. + result_a, result_b = await asyncio.wait_for(asyncio.gather(task_a, task_b), timeout=2.0) + assert result_a is not None and result_b is not None + + async def test_creates_fresh_lock_per_service_instance(self) -> None: + """ + Two TargetService() instances have independent _validate_locks dicts. + Guards the R5 instance-attribute fix against accidental re-promotion + to ClassVar (which would leak locks across pytest event loops). + """ + from unittest.mock import AsyncMock + + service_a = TargetService() + service_b = TargetService() + fake_target_a = _fake_target_with_capabilities() + fake_target_b = _fake_target_with_capabilities() + observed_a = _fake_observed_capabilities(declared=fake_target_a.capabilities) + observed_b = _fake_observed_capabilities(declared=fake_target_b.capabilities) + + # Trigger lock creation in both services for the same registry name + with ( + patch.object(service_a, "get_target_object", return_value=fake_target_a), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + return_value=observed_a, + ), + ): + await service_a.validate_target_capabilities_async(target_registry_name="shared") + + with ( + patch.object(service_b, "get_target_object", return_value=fake_target_b), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + return_value=observed_b, + ), + ): + await service_b.validate_target_capabilities_async(target_registry_name="shared") + + assert "shared" in service_a._validate_locks + assert "shared" in service_b._validate_locks + # Different lock objects per service instance. + assert service_a._validate_locks["shared"] is not service_b._validate_locks["shared"] + + async def test_includes_expected_warnings(self) -> None: + """All five base warnings are present, in the documented order.""" + from unittest.mock import AsyncMock + + service = TargetService() + fake_target = _fake_target_with_capabilities() + observed = _fake_observed_capabilities(declared=fake_target.capabilities) + with ( + patch.object(service, "get_target_object", return_value=fake_target), + patch( + "pyrit.backend.services.target_service.discover_target_capabilities_async", + new_callable=AsyncMock, + return_value=observed, + ), + ): + result = await service.validate_target_capabilities_async(target_registry_name="t1") + + assert result is not None + # Five base warnings (no 6th because no non-probeable modalities). + assert len(result.warnings) == 5 + joined = " | ".join(result.warnings) + assert "live requests" in joined # cost/side-effects + assert "capability_probe" in joined # memory tagging + assert "Output modalities are reported as declared" in joined + assert "semantic enforcement" in joined # request-vs-enforcement caveat + assert "Do not run Validate while an attack" in joined # validate-vs-attack