diff --git a/frontend/src/components/Chat/ConverterPanel.test.tsx b/frontend/src/components/Chat/ConverterPanel.test.tsx index 7c657eebdc..99922b4408 100644 --- a/frontend/src/components/Chat/ConverterPanel.test.tsx +++ b/frontend/src/components/Chat/ConverterPanel.test.tsx @@ -149,6 +149,29 @@ describe('ConverterPanel loading', () => { renderPanel() await waitFor(() => expect(screen.getByTestId('converter-panel-empty')).toBeInTheDocument()) }) + + it('hides base/helper converters that should not be offered', async () => { + const catalogWithHidden = { + items: [ + ...MOCK_CATALOG.items, + { + converter_type: 'SelectiveTextConverter', + supported_input_types: ['text'], + supported_output_types: ['text'], + parameters: [], + is_llm_based: false, + description: 'Base/helper converter.', + }, + ], + } + mockedConvertersApi.listConverterCatalog.mockResolvedValueOnce(catalogWithHidden as ConverterCatalogResponse) + renderPanel() + await waitForList() + + fireEvent.click(getComboboxInput()) + await waitFor(() => expect(screen.getByTestId('converter-option-Base64Converter')).toBeInTheDocument()) + expect(screen.queryByTestId('converter-option-SelectiveTextConverter')).not.toBeInTheDocument() + }) }) // ─── Close button ──────────────────────────────────────────────── diff --git a/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx b/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx index d212e97e94..a079202a84 100644 --- a/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx +++ b/frontend/src/components/Chat/ConverterPanel/ConverterPanel.tsx @@ -18,6 +18,10 @@ const PIECE_TYPE_LABELS: Record = { video: 'Video', } +// Converter classes the backend can build but that aren't useful to offer in the +// picker (base/helper classes). +const HIDDEN_CONVERTER_TYPES = new Set(['SelectiveTextConverter']) + interface ConverterPanelProps { onClose: () => void previewText?: string @@ -51,7 +55,7 @@ export default function ConverterPanel({ onClose, previewText = '', attachmentDa try { const response = await convertersApi.listConverterCatalog() - setConverters(response.items) + setConverters(response.items.filter((c) => !HIDDEN_CONVERTER_TYPES.has(c.converter_type))) } catch (err) { setConverters([]) setSelectedConverterType('') diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 8bd6199592..7fdd99c020 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -15,14 +15,12 @@ import base64 import inspect import mimetypes -import re import uuid from functools import lru_cache from pathlib import Path from typing import Any, Literal, Union, get_args, get_origin from urllib.parse import parse_qs, urlparse -from pyrit import prompt_converter from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.models.converters import ( ConverterCatalogEntry, @@ -38,9 +36,12 @@ ) from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType -from pyrit.prompt_converter import PromptConverter -from pyrit.prompt_target import PromptTarget -from pyrit.registry.object_registries import ConverterRegistry + +# ``get_union_non_none_args`` is a general type-introspection utility used here to +# render parameter types for the catalog (a presentation concern owned by this +# service). +from pyrit.registry.object_registries import ConverterParameterMetadata, ConverterRegistry +from pyrit.registry.resolution import get_union_non_none_args _DATA_TYPE_EXTENSION: dict[str, str] = { "image_path": ".png", @@ -50,169 +51,31 @@ } -def _build_converter_class_registry() -> dict[str, type]: - """ - Build a registry mapping converter class names to their classes. - - Uses the prompt_converter module's __all__ to discover all available converters. - - Returns: - Dict mapping class name (str) to class (type). - """ - registry: dict[str, type] = {} - for name in prompt_converter.__all__: - cls = getattr(prompt_converter, name, None) - if cls is not None and isinstance(cls, type) and issubclass(cls, PromptConverter): - registry[name] = cls - return registry - - -# Module-level class registry (built once on import) -_CONVERTER_CLASS_REGISTRY: dict[str, type] = _build_converter_class_registry() - -# Types that can be rendered as simple form fields -_SIMPLE_TYPES: set[type] = {str, int, float, bool} - - -def _is_simple_type(annotation: Any) -> bool: - """Return True if the annotation represents a type renderable in a form field.""" - if annotation in _SIMPLE_TYPES: - return True - origin = get_origin(annotation) - if origin is Literal: - return True - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - return len(non_none) == 1 and _is_simple_type(non_none[0]) - return False - - def _serialize_type(annotation: Any) -> str: """ - Convert a type annotation to a concise human-readable string. + Render a parameter's type annotation as a concise human-readable string. + + Used to populate the catalog DTO consumed by the frontend (e.g. ``"str"``, + ``"Optional[int]"``, ``"Literal['a', 'b']"``). Returns: str: A human-readable representation of the type annotation. """ if annotation is inspect.Parameter.empty: return "Any" - origin = get_origin(annotation) - if origin is Literal: + if get_origin(annotation) is Literal: args = get_args(annotation) return f"Literal[{', '.join(repr(a) for a in args)}]" - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - inner = _serialize_type(non_none[0]) - return f"Optional[{inner}]" if len(args) > len(non_none) else inner + non_none = get_union_non_none_args(annotation) + if non_none is not None and len(non_none) == 1: + inner = _serialize_type(non_none[0]) + has_none = type(None) in get_args(annotation) + return f"Optional[{inner}]" if has_none else inner if hasattr(annotation, "__name__"): return str(annotation.__name__) return str(annotation) -def _parse_arg_descriptions(converter_class: type) -> dict[str, str]: - """ - Parse parameter descriptions from Google-style docstring Args section. - - Returns: - dict[str, str]: Mapping of parameter names to their descriptions. - """ - doc = (converter_class.__init__.__doc__ or converter_class.__doc__ or "").strip() - match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) - if not match: - return {} - args_block = match.group(1) - # Detect indentation of first parameter line - indent_match = re.match(r"^(\s+)", args_block) - indent = indent_match.group(1) if indent_match else r"\s+" - pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" - descriptions: dict[str, str] = {} - for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): - descriptions[m.group(1)] = " ".join(m.group(2).split()) - return descriptions - - -def _extract_parameters(converter_class: type) -> list[ConverterParameterSchema]: - """ - Extract simple constructor parameters from a converter class. - - Returns: - list[ConverterParameterSchema]: List of parameter schemas. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError): - return [] - - arg_descriptions = _parse_arg_descriptions(converter_class) - - params: list[ConverterParameterSchema] = [] - for name, p in sig.parameters.items(): - if name in ("self", "args", "kwargs"): - continue - if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - continue - if not _is_simple_type(p.annotation): - continue - - no_default = p.default is inspect.Parameter.empty - is_sentinel = hasattr(p.default, "__class__") and "Sentinel" in type(p.default).__name__ - required = no_default or is_sentinel - - default_value: str | None = None - if not required and p.default is not None: - default_value = str(p.default) - - choices: list[str] | None = None - if get_origin(p.annotation) is Literal: - choices = [str(a) for a in get_args(p.annotation)] - - params.append( - ConverterParameterSchema( - name=name, - type_name=_serialize_type(p.annotation), - required=required, - default_value=default_value, - choices=choices, - description=arg_descriptions.get(name), - ) - ) - - return params - - -def _is_llm_based(converter_class: type) -> bool: - """ - Check if the converter requires a target parameter. - - Matches any converter whose ``__init__`` accepts - a ``PromptTarget`` (or subclass) parameter. - These converters perform LLM-based transformations and should not automatically be applied - - Returns: - bool: True if the converter is LLM-based, False otherwise. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError): - return False - - for name, p in sig.parameters.items(): - if name == "self": - continue - ann = p.annotation - if ann is inspect.Parameter.empty: - continue - try: - if isinstance(ann, type) and issubclass(ann, PromptTarget): - return True - except TypeError: - continue - return False - - class ConverterService: """ Service for managing converter instances. @@ -255,43 +118,49 @@ async def list_converters_async(self) -> ConverterInstanceListResponse: async def list_converter_catalog_async(self) -> ConverterCatalogResponse: """ - List all available converter types from the backend converter registry. + List all available converter types from the converter class registry. + + Returns every constructible converter. Deciding which entries to surface + to a user is a presentation concern owned by the caller (e.g. the + frontend), not this service. Returns: ConverterCatalogResponse containing all available converter classes. """ - items: list[ConverterCatalogEntry] = [] - for converter_type, converter_class in sorted(_CONVERTER_CLASS_REGISTRY.items()): - if ( - converter_type in ("PromptConverter", "ConverterResult", "SelectiveTextConverter") - or "Strategy" in converter_type - ): - continue - - supported_input_types = [ - str(data_type) for data_type in getattr(converter_class, "SUPPORTED_INPUT_TYPES", ()) - ] - supported_output_types = [ - str(data_type) for data_type in getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ()) - ] - - # Extract first paragraph of docstring as description - raw_doc = (converter_class.__doc__ or "").strip() - description = raw_doc.split("\n\n")[0].replace("\n", " ").strip() or None - - items.append( - ConverterCatalogEntry( - converter_type=converter_type, - supported_input_types=supported_input_types, - supported_output_types=supported_output_types, - parameters=_extract_parameters(converter_class), - is_llm_based=_is_llm_based(converter_class), - description=description, - ) + items: list[ConverterCatalogEntry] = [ + ConverterCatalogEntry( + converter_type=metadata.class_name, + supported_input_types=list(metadata.supported_input_types), + supported_output_types=list(metadata.supported_output_types), + parameters=[self._build_parameter_schema(p) for p in metadata.parameters if p.coercible_from_string], + is_llm_based=metadata.is_llm_based, + description=metadata.class_description or None, ) + for metadata in self._registry.list_class_metadata() + ] return ConverterCatalogResponse(items=items) + @staticmethod + def _build_parameter_schema(parameter: ConverterParameterMetadata) -> ConverterParameterSchema: + """ + Map registry parameter metadata to the catalog DTO. + + Renders the raw annotation to a human-readable ``type_name`` for the + frontend (presentation concern owned by this service). + + Returns: + ConverterParameterSchema: The parameter schema for the catalog entry. + """ + return ConverterParameterSchema( + name=parameter.name, + type_name=_serialize_type(parameter.annotation), + required=parameter.required, + default_value=parameter.default_value, + choices=list(parameter.choices) if parameter.choices is not None else None, + description=parameter.description, + ) + async def get_converter_async(self, *, converter_id: str) -> ConverterInstance | None: """ Get a converter instance by ID. @@ -331,12 +200,16 @@ async def create_converter_async(self, *, request: CreateConverterRequest) -> Cr """ converter_id = str(uuid.uuid4()) - # Resolve any converter references in params and instantiate + # Resolve any converter references in params, persist data-URI params to + # disk (frontend concern), then delegate construction (incl. param + # coercion) to the converter registry. params = self._resolve_converter_params(params=request.params) - converter_class = self._get_converter_class(converter_type=request.type) - params = self._coerce_params(converter_class=converter_class, params=params) + try: + converter_class = self._registry.get_class(request.type) + except KeyError as e: + raise ValueError(f"Converter type '{request.type}' not found") from e params = await self._persist_data_uri_params_async(converter_class=converter_class, params=params) - converter_obj = converter_class(**params) + converter_obj = self._registry.create_instance(request.type, **params) self._registry.register_instance(converter_obj, name=converter_id) return CreateConverterResponse( @@ -431,29 +304,6 @@ def get_converter_objects_for_ids(self, *, converter_ids: list[str]) -> list[Any # Private Helper Methods # ======================================================================== - def _get_converter_class(self, *, converter_type: str) -> type: - """ - Get the converter class for a given type name. - - Looks up the class in the module-level converter class registry. - - Args: - converter_type: The exact class name of the converter (e.g., 'Base64Converter'). - - Returns: - The converter class. - - Raises: - ValueError: If the converter type is not found. - """ - cls = _CONVERTER_CLASS_REGISTRY.get(converter_type) - if cls is None: - raise ValueError( - f"Converter type '{converter_type}' not found. " - f"Available types: {sorted(_CONVERTER_CLASS_REGISTRY.keys())}" - ) - return cls - def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any]: """ Resolve converter references in params. @@ -474,53 +324,6 @@ def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any] resolved["converter"] = conv_obj return resolved - @staticmethod - def _coerce_params(*, converter_class: type, params: dict[str, Any]) -> dict[str, Any]: - """ - Coerce parameter values to match the converter's __init__ type annotations. - - The frontend sends all values as strings; this converts them to int, float, - or bool as needed based on the constructor signature. - - Returns: - Params dict with values coerced to the expected types. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError) as e: - raise ValueError( - f"Failed to inspect __init__ signature for converter '{converter_class.__name__}': {e}" - ) from e - - coerced = dict(params) - for name, value in coerced.items(): - if name not in sig.parameters or not isinstance(value, str): - continue - annotation = sig.parameters[name].annotation - if annotation is inspect.Parameter.empty: - continue - - origin = get_origin(annotation) - # Unwrap X | None to X - if origin is Union: - args = get_args(annotation) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - annotation = non_none[0] - origin = get_origin(annotation) - - try: - if annotation is int: - coerced[name] = int(value) - elif annotation is float: - coerced[name] = float(value) - elif annotation is bool: - coerced[name] = value.lower() in ("true", "1", "yes") - except (ValueError, TypeError) as e: - raise ValueError(f"Parameter '{name}' expects {annotation.__name__}, got {value!r}") from e - - return coerced - @staticmethod async def _persist_data_uri_params_async( *, diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index c8236aceda..b4997828a0 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -94,9 +94,7 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSu init_kwargs = self._build_init_kwargs( request=request, scenario_class=scenario_class, objective_target=objective_target ) - scenario = await self._initialize_scenario_async( - request=request, scenario_class=scenario_class, init_kwargs=init_kwargs - ) + scenario = await self._initialize_scenario_async(request=request, init_kwargs=init_kwargs) except Exception: self._run_semaphore.release() raise @@ -371,15 +369,13 @@ def _build_init_kwargs( return init_kwargs - async def _initialize_scenario_async( - self, *, request: RunScenarioRequest, scenario_class: type[Scenario], init_kwargs: dict[str, Any] - ) -> Scenario: + async def _initialize_scenario_async(self, *, request: RunScenarioRequest, init_kwargs: dict[str, Any]) -> Scenario: """ Instantiate the scenario and call initialize_async. Args: - request: The run request (for scenario_params and scenario_result_id). - scenario_class: The resolved scenario class. + request: The run request (for scenario_name, scenario_params, and + scenario_result_id). init_kwargs: The kwargs to pass to scenario.initialize_async. Returns: @@ -388,7 +384,8 @@ async def _initialize_scenario_async( constructor_kwargs: dict[str, Any] = {} if request.scenario_result_id: constructor_kwargs["scenario_result_id"] = request.scenario_result_id - scenario = scenario_class(**constructor_kwargs) # type: ignore[call-arg] + scenario_registry = ScenarioRegistry.get_registry_singleton() + scenario = scenario_registry.create_instance(request.scenario_name, **constructor_kwargs) scenario.set_params_from_args(args=request.scenario_params or {}) await scenario.initialize_async(**init_kwargs) return scenario @@ -488,7 +485,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari scenario_version=scenario_result.scenario_identifier.version, status=status, created_at=scenario_result.creation_time, - updated_at=scenario_result.completion_time, + updated_at=scenario_result.completion_time or scenario_result.creation_time, error=error, error_type=error_type, strategies_used=strategies_used, diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index cd94382b93..05c1e9eb46 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -4,6 +4,7 @@ """Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol +from pyrit.registry.buildable_registry import BuildableRegistry from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, @@ -13,6 +14,7 @@ ScenarioParameterMetadata, ScenarioRegistry, ) +from pyrit.registry.container_registry import ContainerRegistry from pyrit.registry.discovery import ( discover_in_directory, discover_in_package, @@ -21,6 +23,8 @@ from pyrit.registry.object_registries import ( AttackTechniqueRegistry, BaseInstanceRegistry, + ConverterMetadata, + ConverterParameterMetadata, ConverterRegistry, RegistryEntry, RetrievableInstanceRegistry, @@ -33,6 +37,10 @@ "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", + "BuildableRegistry", + "ContainerRegistry", + "ConverterMetadata", + "ConverterParameterMetadata", "ConverterRegistry", "RetrievableInstanceRegistry", "ClassEntry", diff --git a/pyrit/registry/buildable_registry.py b/pyrit/registry/buildable_registry.py new file mode 100644 index 0000000000..07fd5df91f --- /dev/null +++ b/pyrit/registry/buildable_registry.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Buildable registry base for PyRIT. + +``BuildableRegistry`` is the universal registry capability: discover classes, +introspect them into metadata, and **build** configured instances from a type +name plus a flat argument dict. Construction routes through the shared +``resolve_constructor_args`` primitive, so simple values are coerced and +registry-reference parameters (e.g. a ``PromptTarget``) are resolved by name — +the same mechanism for every domain. + +Every PyRIT registry is buildable. Registries that additionally hold named +instances extend ``ContainerRegistry`` (which adds the instance container on top +of this layer). +""" + +from __future__ import annotations + +from typing import TypeVar + +from pyrit.registry.class_registries.base_class_registry import BaseClassRegistry +from pyrit.registry.resolution import resolve_constructor_args + +T = TypeVar("T") +MetadataT = TypeVar("MetadataT") + + +class BuildableRegistry(BaseClassRegistry[T, MetadataT]): + """ + Registry base that can build instances from a type name and arguments. + + Extends the class-table infrastructure of ``BaseClassRegistry`` with a + construction path that routes through ``resolve_constructor_args``: string + values are coerced to their annotated scalar types and registry-reference + parameters are resolved by name from the owning domain's registry. A + registered factory, when present, is used as-is (its arguments are not + resolved, since a factory owns its own construction semantics). + + Type Parameters: + T: The type of classes being registered (e.g. ``PromptConverter``). + MetadataT: The metadata dataclass type (e.g. ``ConverterMetadata``). + """ + + def get_class_names(self) -> list[str]: + """ + Get a sorted list of all registered class names. + + Always reflects the class catalog, even on container registries where the + protocol surface (``get_names``) refers to instances. + + Returns: + list[str]: The sorted class-catalog names. + """ + self._ensure_discovered() + return sorted(self._class_entries.keys()) + + def get_class(self, name: str) -> type[T]: + """ + Get a registered class by its catalog name. + + Overrides the base lookup so the "not found" error lists the class catalog + (``get_class_names``) rather than the instance container that a + ``ContainerRegistry`` exposes through ``get_names``. + + Args: + name (str): The class-catalog name to resolve. + + Returns: + type[T]: The registered class. + + Raises: + KeyError: If the name is not registered in the class catalog. + """ + self._ensure_discovered() + entry = self._class_entries.get(name) + if entry is None: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + return entry.registered_class + + def list_class_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[MetadataT]: + """ + List metadata for all registered classes, optionally filtered. + + This is the class-catalog metadata (one entry per registered class), + distinct from any instance-level metadata a container registry exposes. + It always reflects the class catalog, even on container registries where + ``list_metadata`` refers to instances. + + Args: + include_filters (dict[str, object] | None): Filters items must match. + exclude_filters (dict[str, object] | None): Filters items must not match. + + Returns: + list[MetadataT]: Metadata describing each registered class. + """ + return BaseClassRegistry.list_metadata(self, include_filters=include_filters, exclude_filters=exclude_filters) + + def create_instance(self, name: str, **kwargs: object) -> T: + """ + Build a configured instance by class name. + + Arguments are resolved via ``resolve_constructor_args`` (coerce simple + strings, resolve registry references by name, raise on unknown params). + When the class is registered with a factory, the factory is invoked + directly with the given arguments instead. + + Args: + name (str): The class-catalog name to build. + **kwargs (object): Constructor arguments (simple values or registry + names for reference parameters). + + Returns: + T: The constructed instance. + + Raises: + KeyError: If the name is not registered. + ValueError: If an argument is not a valid constructor parameter, a + registry reference cannot be resolved, or a value cannot be coerced. + """ + self._ensure_discovered() + entry = self._class_entries.get(name) + if entry is None: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + + if entry.factory is not None: + return entry.create_instance(**kwargs) + + raw_args = {**entry.default_kwargs, **kwargs} + resolved = resolve_constructor_args(cls=entry.registered_class, raw_args=raw_args) + return entry.registered_class(**resolved) diff --git a/pyrit/registry/container_registry.py b/pyrit/registry/container_registry.py new file mode 100644 index 0000000000..95f3ec9f21 --- /dev/null +++ b/pyrit/registry/container_registry.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Container registry base for PyRIT. + +``ContainerRegistry`` extends ``BuildableRegistry`` with an instance container: +in addition to discovering classes and building instances (the buildable layer), +it holds named, pre-configured instances that callers register and retrieve. +This is the base for domains that are both buildable *and* hold instances +(converters, targets, scorers). + +The container is the registry's primary identity: the protocol surface +(``get_names``, ``__contains__``, ``__len__``, ``__iter__``, ``list_metadata``) +refers to **instances**. The class catalog is reached through the explicitly +named buildable methods (``get_class``, ``get_class_names``, +``list_class_metadata``, ``create_instance``). This keeps name-based resolution +consistent across every container registry. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from pyrit.models import ComponentIdentifier, Identifiable +from pyrit.registry.buildable_registry import BuildableRegistry + +if TYPE_CHECKING: + from collections.abc import Iterator + + from pyrit.registry.object_registries.base_instance_registry import RegistryEntry + +T = TypeVar("T", bound=Identifiable) +MetadataT = TypeVar("MetadataT") + + +class ContainerRegistry(BuildableRegistry[T, MetadataT], Generic[T, MetadataT]): + """ + Registry base that is buildable *and* holds named instances. + + Adds an instance container on top of ``BuildableRegistry``: register + pre-configured instances, retrieve them by name, list and tag them. Stored + instances must implement ``Identifiable`` so instance metadata can be derived + from ``get_identifier()``. + + The container is primary: ``get_names``/``__contains__``/``__len__``/ + ``__iter__``/``list_metadata`` operate on instances. The class catalog is + accessed via the buildable methods inherited from ``BuildableRegistry``. + + Type Parameters: + T: The type of instances held (must be ``Identifiable``). + MetadataT: The class-catalog metadata type. + """ + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry. + + Args: + lazy_discovery (bool): If True, class discovery is deferred until first + access. If False, discovery runs immediately. + """ + super().__init__(lazy_discovery=lazy_discovery) + self._instance_entries: dict[str, RegistryEntry[T]] = {} + self._instance_metadata_cache: list[ComponentIdentifier] | None = None + + @staticmethod + def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: + """ + Normalize tags into a ``dict[str, str]``. + + Args: + tags (dict[str, str] | list[str] | None): Tags as a dict, a list of + string keys (values default to ``""``), or None (empty dict). + + Returns: + dict[str, str]: The normalized tags. + """ + if tags is None: + return {} + if isinstance(tags, list): + return dict.fromkeys(tags, "") + return dict(tags) + + def register_instance( + self, + instance: T, + *, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """ + Register a pre-configured instance in the container. + + Args: + instance (T): The instance to register. + name (str | None): The registry name. Defaults to the instance's + identifier ``unique_name``. + tags (dict[str, str] | list[str] | None): Optional tags for + categorization. + metadata (dict[str, Any] | None): Optional per-entry metadata. + """ + if name is None: + name = instance.get_identifier().unique_name + + from pyrit.registry.object_registries.base_instance_registry import RegistryEntry + + self._instance_entries[name] = RegistryEntry( + name=name, + instance=instance, + tags=self._normalize_tags(tags), + metadata=metadata or {}, + ) + self._instance_metadata_cache = None + + def get_instance_by_name(self, name: str) -> T | None: + """ + Get a registered instance by name. + + Args: + name (str): The registry name of the instance. + + Returns: + T | None: The instance, or None if not found. + """ + entry = self._instance_entries.get(name) + return entry.instance if entry is not None else None + + def get_instance_entry(self, name: str) -> RegistryEntry[T] | None: + """ + Get the full instance entry (including tags) by name. + + Args: + name (str): The registry name of the entry. + + Returns: + RegistryEntry[T] | None: The entry, or None if not found. + """ + return self._instance_entries.get(name) + + def get_all_instances(self) -> list[RegistryEntry[T]]: + """ + Get all registered instance entries sorted by name. + + Returns: + list[RegistryEntry[T]]: The instance entries sorted by name. + """ + return [self._instance_entries[name] for name in sorted(self._instance_entries.keys())] + + def get_by_tag(self, *, tag: str, value: str | None = None) -> list[RegistryEntry[T]]: + """ + Get instance entries that carry a given tag, optionally matching a value. + + Args: + tag (str): The tag key to match. + value (str | None): If provided, only entries whose tag value equals + this are returned. If None, any entry with the tag key matches. + + Returns: + list[RegistryEntry[T]]: Matching entries sorted by name. + """ + results: list[RegistryEntry[T]] = [] + for name in sorted(self._instance_entries.keys()): + entry = self._instance_entries[name] + if tag in entry.tags and (value is None or entry.tags[tag] == value): + results.append(entry) + return results + + def add_tags(self, *, name: str, tags: dict[str, str] | list[str]) -> None: + """ + Add tags to an existing instance entry. + + Args: + name (str): The registry name of the entry to tag. + tags (dict[str, str] | list[str]): Tags to add. + + Raises: + KeyError: If no entry with the given name exists. + """ + entry = self._instance_entries.get(name) + if entry is None: + raise KeyError(f"No instance named '{name}' in registry.") + entry.tags.update(self._normalize_tags(tags)) + self._instance_metadata_cache = None + + def list_instance_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: + """ + List metadata for all registered instances, optionally filtered. + + Args: + include_filters (dict[str, object] | None): Filters items must match. + exclude_filters (dict[str, object] | None): Filters items must not match. + + Returns: + list[ComponentIdentifier]: The identifier metadata for each instance. + """ + from pyrit.registry.base import _matches_filters + + if self._instance_metadata_cache is None: + self._instance_metadata_cache = [ + self._instance_entries[name].instance.get_identifier() for name in sorted(self._instance_entries.keys()) + ] + + if not include_filters and not exclude_filters: + return self._instance_metadata_cache + + return [ + m + for m in self._instance_metadata_cache + if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) + ] + + # ------------------------------------------------------------------ + # Protocol surface — operates on the instance container (primary identity) + # ------------------------------------------------------------------ + + def get_names(self) -> list[str]: + """ + Get a sorted list of all registered instance names. + + Returns: + list[str]: The instance names sorted alphabetically. + """ + return sorted(self._instance_entries.keys()) + + def list_metadata( # type: ignore[ty:invalid-method-override] + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: + """ + List instance metadata (the container is the primary identity). + + Intentionally narrows the return type to instance ``ComponentIdentifier`` + metadata: on a container registry the protocol surface refers to + instances. Class-catalog metadata is available via ``list_class_metadata``. + + Args: + include_filters (dict[str, object] | None): Filters items must match. + exclude_filters (dict[str, object] | None): Filters items must not match. + + Returns: + list[ComponentIdentifier]: The identifier metadata for each instance. + """ + return self.list_instance_metadata(include_filters=include_filters, exclude_filters=exclude_filters) + + def __contains__(self, name: str) -> bool: + """ + Check if an instance name is registered. + + Returns: + bool: True if the instance name is registered, False otherwise. + """ + return name in self._instance_entries + + def __len__(self) -> int: + """ + Get the count of registered instances. + + Returns: + int: The number of registered instances. + """ + return len(self._instance_entries) + + def __iter__(self) -> Iterator[str]: + """ + Iterate over registered instance names. + + Returns: + Iterator[str]: An iterator over sorted instance names. + """ + return iter(sorted(self._instance_entries.keys())) diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index b6edf16088..fd2fb9cb01 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -19,6 +19,8 @@ RegistryEntry, ) from pyrit.registry.object_registries.converter_registry import ( + ConverterMetadata, + ConverterParameterMetadata, ConverterRegistry, ) from pyrit.registry.object_registries.retrievable_instance_registry import ( @@ -39,6 +41,8 @@ # Concrete registries "AttackTechniqueRegistry", "ConverterRegistry", + "ConverterMetadata", + "ConverterParameterMetadata", "ScorerRegistry", "TargetRegistry", ] diff --git a/pyrit/registry/object_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py index 568d1e6332..28cf3e9d7e 100644 --- a/pyrit/registry/object_registries/converter_registry.py +++ b/pyrit/registry/object_registries/converter_registry.py @@ -2,21 +2,36 @@ # Licensed under the MIT license. """ -Converter registry for managing PyRIT converter instances. +Converter registry for PyRIT. -Converters are registered explicitly via initializers as pre-configured instances. +A single registry for ``PromptConverter`` that both: -NOTE: This is a placeholder implementation. A full implementation will be added soon. +- **builds** converters from a type name plus arguments — discovering converter + classes, introspecting their constructor parameters, and constructing instances + via the shared resolver (so LLM converters can be built by passing a + ``converter_target`` registry name), and +- **holds** pre-configured converter instances registered via initializers or the + backend. + +It extends ``ContainerRegistry``: the class catalog is reached through the +buildable methods (``get_class``, ``list_class_metadata``, +``create_instance``) while the instance container is the primary surface +(``register_instance``, ``get_instance_by_name``, ``get_all_instances``, +``get_names``). """ from __future__ import annotations +import inspect import logging -from typing import TYPE_CHECKING +import re +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, get_args, get_origin -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.class_registries.base_class_registry import ClassEntry +from pyrit.registry.container_registry import ContainerRegistry +from pyrit.registry.resolution import get_union_non_none_args, is_coercible_from_string if TYPE_CHECKING: from pyrit.prompt_converter import PromptConverter @@ -24,46 +39,230 @@ logger = logging.getLogger(__name__) -class ConverterRegistry(RetrievableInstanceRegistry["PromptConverter"]): +class ConverterParameterMetadata(NamedTuple): + """ + A converter constructor parameter described for dynamic construction. + + Carries raw introspection data so callers can build converters on the fly. + ``annotation`` is the parameter's raw type annotation; rendering it to a + human-readable string is a presentation concern left to the caller. + ``coercible_from_string`` is True when a string value can be coerced to the + annotated type. ``requires_llm`` is True when the parameter expects a + ``PromptTarget`` (i.e. the converter performs an LLM-based transformation). + + NamedTuple so consumers can read fields by name while the value stays + immutable (safe to cache inside a frozen ``ConverterMetadata``). + """ + + name: str + annotation: Any + required: bool + default_value: str | None + choices: tuple[str, ...] | None + description: str | None + coercible_from_string: bool + requires_llm: bool + + +@dataclass(frozen=True) +class ConverterMetadata(ClassRegistryEntry): + """ + Metadata describing a registered ``PromptConverter`` class. + + Use ``ConverterRegistry.get_class()`` to get the actual class or + ``create_instance()`` to build a configured instance. + """ + + # Input data types the converter accepts (stringified PromptDataType values). + supported_input_types: tuple[str, ...] = field(kw_only=True, default=()) + + # Output data types the converter produces (stringified PromptDataType values). + supported_output_types: tuple[str, ...] = field(kw_only=True, default=()) + + # Simple constructor parameters suitable for dynamic form generation. + parameters: tuple[ConverterParameterMetadata, ...] = field(kw_only=True, default=()) + + # Whether the converter requires an LLM target. + is_llm_based: bool = field(kw_only=True, default=False) + + +def _requires_llm_target(annotation: Any) -> bool: + """ + Return True if the annotation expects a ``PromptTarget`` (or subclass). + + Handles unioned forms such as ``PromptTarget | None``. A converter parameter + with such an annotation indicates the converter performs an LLM-based + transformation. + + Returns: + bool: True if the annotation expects a ``PromptTarget``, False otherwise. + """ + if annotation is inspect.Parameter.empty: + return False + + from pyrit.prompt_target import PromptTarget + + candidates = get_union_non_none_args(annotation) + if candidates is None: + candidates = [annotation] + for candidate in candidates: + try: + if isinstance(candidate, type) and issubclass(candidate, PromptTarget): + return True + except TypeError: + continue + return False + + +def _parse_arg_descriptions(converter_class: type) -> dict[str, str]: + """ + Parse parameter descriptions from a Google-style docstring Args section. + + Returns: + dict[str, str]: Mapping of parameter names to their descriptions. + """ + doc = (converter_class.__init__.__doc__ or converter_class.__doc__ or "").strip() + match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) + if not match: + return {} + args_block = match.group(1) + # Detect indentation of first parameter line + indent_match = re.match(r"^(\s+)", args_block) + indent = indent_match.group(1) if indent_match else r"\s+" + pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" + descriptions: dict[str, str] = {} + for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): + descriptions[m.group(1)] = " ".join(m.group(2).split()) + return descriptions + + +def _extract_parameters(converter_class: type) -> tuple[ConverterParameterMetadata, ...]: + """ + Extract constructor parameters from a converter class. + + Surfaces every settable constructor parameter (excluding ``self`` and + var-args) so a caller has the full picture for dynamic construction. Each + parameter records its raw ``annotation`` and a ``coercible_from_string`` flag + indicating whether a string value can be coerced to its type. + + Returns: + tuple[ConverterParameterMetadata, ...]: The constructor parameters. + """ + try: + sig = inspect.signature(converter_class.__init__) + except (ValueError, TypeError): + return () + + arg_descriptions = _parse_arg_descriptions(converter_class) + + params: list[ConverterParameterMetadata] = [] + for name, p in sig.parameters.items(): + if name in ("self", "args", "kwargs"): + continue + if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + continue + + no_default = p.default is inspect.Parameter.empty + is_sentinel = hasattr(p.default, "__class__") and "Sentinel" in type(p.default).__name__ + required = no_default or is_sentinel + + default_value: str | None = None + if not required and p.default is not None: + default_value = str(p.default) + + choices: tuple[str, ...] | None = None + choice_annotation = p.annotation + non_none_choice = get_union_non_none_args(choice_annotation) + if non_none_choice is not None and len(non_none_choice) == 1: + choice_annotation = non_none_choice[0] + if get_origin(choice_annotation) is Literal: + choices = tuple(str(a) for a in get_args(choice_annotation)) + + params.append( + ConverterParameterMetadata( + name=name, + annotation=p.annotation, + required=required, + default_value=default_value, + choices=choices, + description=arg_descriptions.get(name), + coercible_from_string=is_coercible_from_string(p.annotation), + requires_llm=_requires_llm_target(p.annotation), + ) + ) + + return tuple(params) + + +class ConverterRegistry(ContainerRegistry["PromptConverter", ConverterMetadata]): """ - Registry for managing available converter instances. + Registry that discovers, builds, and holds ``PromptConverter`` instances. - This registry stores pre-configured PromptConverter instances (not classes). - Converters are registered explicitly via initializers after being instantiated - with their required parameters. + Discovers all concrete ``PromptConverter`` subclasses exported from + ``pyrit.prompt_converter`` (keyed by their exact class name, e.g. + ``"Base64Converter"``) for the buildable catalog, and holds pre-configured + instances registered via initializers or the backend. + + Building a converter resolves its arguments through the shared resolver, so + LLM converters can be constructed by passing a ``converter_target`` that names + a target in the ``TargetRegistry``. """ - def register_instance( - self, - converter: PromptConverter, - *, - name: str | None = None, - tags: dict[str, str] | list[str] | None = None, - ) -> None: + def _get_registry_name(self, cls: type) -> str: """ - Register a converter instance. + Use the exact class name as the catalog key. - Args: - converter: The pre-configured converter instance (not a class). - name: Optional custom registry name. If not provided, - derived from the converter's unique identifier. - tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). + Converters are referenced by their class name (e.g. ``"Base64Converter"``) + rather than the snake_case default used by other class registries. + + Returns: + str: The class name. """ - if name is None: - name = converter.get_identifier().unique_name + return cls.__name__ + + def _discover(self) -> None: + """Discover all concrete ``PromptConverter`` subclasses from ``pyrit.prompt_converter``.""" + from pyrit import prompt_converter + from pyrit.prompt_converter import PromptConverter - self.register(converter, name=name, tags=tags) - logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") + for name in prompt_converter.__all__: + cls = getattr(prompt_converter, name, None) + if cls is None or not isinstance(cls, type): + continue + if not issubclass(cls, PromptConverter) or cls is PromptConverter: + continue + self._class_entries[name] = ClassEntry(registered_class=cls) + logger.debug(f"Registered converter class: {name}") - def get_instance_by_name(self, name: str) -> PromptConverter | None: + def _build_metadata(self, name: str, entry: ClassEntry[PromptConverter]) -> ConverterMetadata: """ - Get a registered converter instance by name. + Build catalog metadata for a ``PromptConverter`` class. Args: - name: The registry name of the converter. + name (str): The catalog name (exact class name) of the converter. + entry (ClassEntry[PromptConverter]): The class entry being described. Returns: - The converter instance, or None if not found. + ConverterMetadata: Metadata describing the converter class. """ - return self.get(name) + converter_class = entry.registered_class + + # First paragraph of the docstring as a short description. + raw_doc = (converter_class.__doc__ or "").strip() + description = raw_doc.split("\n\n")[0].replace("\n", " ").strip() + + supported_input_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_INPUT_TYPES", ())) + supported_output_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ())) + + parameters = _extract_parameters(converter_class) + + return ConverterMetadata( + class_name=converter_class.__name__, + class_module=converter_class.__module__, + class_description=description, + registry_name=name, + supported_input_types=supported_input_types, + supported_output_types=supported_output_types, + parameters=parameters, + is_llm_based=any(p.requires_llm for p in parameters), + ) diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py new file mode 100644 index 0000000000..511e579821 --- /dev/null +++ b/pyrit/registry/resolution.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Constructor-argument resolution for PyRIT registries. + +This is the shared mechanism that lets any registry build an instance from a +type name plus a flat dict of arguments. Build inputs are exactly two kinds: + +- **Simple values** — strings/ints/floats/bools (and ``Literal`` choices) that + can be coerced to the constructor's annotated type. +- **Registry references** — a parameter whose annotation is a domain base type + (``PromptTarget``, ``PromptConverter``, ``Scorer``) is supplied *by name* and + resolved from that domain's registry. An already-constructed instance passes + through unchanged. + +Unknown parameters raise, so a caller (form, agent, attack strategy) gets a +clear error instead of having values silently dropped. + +This module performs no eager heavy imports and never imports ``pyrit.backend``: +the resolvable-registry lookups are done lazily so it can be reused anywhere. +""" + +from __future__ import annotations + +import inspect +import types +from typing import TYPE_CHECKING, Any, Literal, Protocol, Union, get_args, get_origin + +if TYPE_CHECKING: + from collections.abc import Callable + +# Scalar Python types whose string values can be coerced to the real type. +_SIMPLE_TYPES: set[type] = {str, int, float, bool} + + +class _NamedInstanceRegistry(Protocol): + """Structural type for a registry that resolves stored instances by name.""" + + def get_instance_by_name(self, name: str) -> Any | None: + """Return the instance registered under ``name``, or None.""" + ... + + def get_names(self) -> list[str]: + """Return the sorted names of registered instances.""" + ... + + +def get_union_non_none_args(annotation: Any) -> list[Any] | None: + """ + Return the non-``None`` members of a union annotation, or None if not a union. + + Handles both ``typing.Union[X, None]`` and PEP 604 ``X | None``. This is a + general type-introspection utility (not presentation), reused by coercion, + registry-reference detection, and callers that need to render a type. + + Args: + annotation (Any): The type annotation to inspect. + + Returns: + list[Any] | None: The non-None union members, or None when the annotation + is not a union. + """ + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + return [a for a in get_args(annotation) if a is not type(None)] + return None + + +def is_coercible_from_string(annotation: Any) -> bool: + """ + Return True if a string value can be coerced to the annotated type. + + Covers the scalar types in ``_SIMPLE_TYPES`` (str/int/float/bool), + ``Literal`` annotations, and an ``Optional`` wrapping one of those. + + Returns: + bool: True if the annotation is coercible from a string, False otherwise. + """ + if annotation in _SIMPLE_TYPES: + return True + if get_origin(annotation) is Literal: + return True + non_none = get_union_non_none_args(annotation) + if non_none is not None: + return len(non_none) == 1 and is_coercible_from_string(non_none[0]) + return False + + +def _resolvable_registries() -> list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: + """ + Return the (base type -> registry singleton getter) pairs that can be resolved by name. + + A constructor parameter whose annotation is (a subclass of) one of these base + types is supplied by name and looked up in the paired registry. Imports are + deferred so this core module stays import-light and free of cycles. + + Returns: + list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: The resolvable + domain base types paired with a callable returning their registry singleton. + """ + from pyrit.prompt_converter import PromptConverter + from pyrit.prompt_target import PromptTarget + from pyrit.registry.object_registries import ( + ConverterRegistry, + ScorerRegistry, + TargetRegistry, + ) + from pyrit.score.scorer import Scorer + + return [ + (PromptTarget, TargetRegistry.get_registry_singleton), + (PromptConverter, ConverterRegistry.get_registry_singleton), + (Scorer, ScorerRegistry.get_registry_singleton), + ] + + +def get_resolvable_registry_getter(annotation: Any) -> Callable[[], _NamedInstanceRegistry] | None: + """ + Return the registry-singleton getter for a registry-reference annotation. + + The annotation matches when it is (or unions, e.g. ``X | None``, to) a subclass + of a resolvable domain base type. A parameter with such an annotation is + supplied by name and resolved from the returned registry. + + Args: + annotation (Any): The parameter's type annotation. + + Returns: + Callable[[], _NamedInstanceRegistry] | None: A callable returning the + registry singleton, or None when the annotation is not a registry reference. + """ + if annotation is inspect.Parameter.empty: + return None + + candidates = get_union_non_none_args(annotation) + if candidates is None: + candidates = [annotation] + + for base_type, getter in _resolvable_registries(): + for candidate in candidates: + try: + if isinstance(candidate, type) and issubclass(candidate, base_type): + return getter + except TypeError: + continue + return None + + +def is_registry_reference(annotation: Any) -> bool: + """ + Return True if the annotation is a registry reference (resolved by name). + + Returns: + bool: True if a value for this parameter is supplied by name and resolved + from a registry, False otherwise. + """ + return get_resolvable_registry_getter(annotation) is not None + + +def coerce_string_to_annotation(*, value: str, annotation: Any) -> Any: + """ + Coerce a string value to the annotated scalar type (int/float/bool/Literal). + + ``Optional[X]`` / ``X | None`` is unwrapped to ``X`` first. A ``Literal`` value + is validated against the allowed members and returned as the matching member + (so an int literal comes back as an ``int``); other ``str`` values pass through + unchanged. + + Args: + value (str): The raw string value. + annotation (Any): The parameter's type annotation. + + Returns: + Any: The value coerced to the annotated type, or the original string when + no numeric/boolean/Literal coercion applies. + + Raises: + ValueError: If the value cannot be interpreted as the annotated type, or is + not one of the allowed members of an annotated ``Literal``. + """ + if annotation is inspect.Parameter.empty: + return value + + non_none = get_union_non_none_args(annotation) + if non_none is not None and len(non_none) == 1: + annotation = non_none[0] + + if get_origin(annotation) is Literal: + allowed = get_args(annotation) + for member in allowed: + if value == str(member): + return member + raise ValueError(f"expected one of {[str(a) for a in allowed]}, got {value!r}") + + if annotation is int: + return int(value) + if annotation is float: + return float(value) + if annotation is bool: + lowered = value.strip().lower() + if lowered in ("true", "1", "yes"): + return True + if lowered in ("false", "0", "no"): + return False + raise ValueError(f"cannot interpret {value!r} as a boolean") + return value + + +def _resolve_registry_reference( + *, value: Any, getter: Callable[[], _NamedInstanceRegistry], owner: str, name: str +) -> Any: + """ + Resolve a registry-reference parameter value to a stored instance. + + A string value is looked up by name in the paired registry. An already-built + instance passes through unchanged. + + Args: + value (Any): The raw value (a registry name, or an instance to pass through). + getter (Callable[[], _NamedInstanceRegistry]): Returns the registry singleton. + owner (str): The owning class name, for error messages. + name (str): The parameter name, for error messages. + + Returns: + Any: The resolved instance. + + Raises: + ValueError: If the name is not registered. + """ + if not isinstance(value, str): + return value + + registry = getter() + instance = registry.get_instance_by_name(value) + if instance is not None: + return instance + + registry_label = type(registry).__name__ + available_names = registry.get_names() + if not available_names: + raise ValueError( + f"{owner}.{name}: '{value}' not found. The {registry_label} is empty. " + "Make sure to register instances (e.g. via an initializer) before building " + "components that reference them by name." + ) + raise ValueError( + f"{owner}.{name}: '{value}' not found in {registry_label}. Available: {', '.join(available_names)}" + ) + + +def resolve_constructor_args(*, cls: type, raw_args: dict[str, Any]) -> dict[str, Any]: + """ + Resolve a flat argument dict into constructor-ready keyword arguments. + + For each argument: validate it is a real constructor parameter (unless the + constructor accepts ``**kwargs``); resolve registry-reference parameters by + name; coerce simple string values to their annotated scalar type; pass + everything else through unchanged. + + Args: + cls (type): The class whose ``__init__`` signature drives resolution. + raw_args (dict[str, Any]): The raw argument values (e.g. from a form or agent). + + Returns: + dict[str, Any]: Arguments ready to pass to ``cls(**resolved)``. + + Raises: + ValueError: If the signature cannot be inspected, an argument is not a + valid constructor parameter, a registry reference cannot be resolved, + or a simple value cannot be coerced. + """ + try: + sig = inspect.signature(cls.__init__) + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to inspect __init__ signature for '{cls.__name__}': {e}") from e + + accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + valid_params = { + param_name: p + for param_name, p in sig.parameters.items() + if param_name != "self" and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + } + + resolved: dict[str, Any] = {} + for name, value in raw_args.items(): + param = valid_params.get(name) + if param is None and not accepts_var_kwargs: + raise ValueError( + f"Unknown parameter '{name}' for '{cls.__name__}'. Valid parameters: {sorted(valid_params.keys())}" + ) + + annotation = param.annotation if param is not None else inspect.Parameter.empty + + registry_getter = get_resolvable_registry_getter(annotation) + if registry_getter is not None: + resolved[name] = _resolve_registry_reference( + value=value, getter=registry_getter, owner=cls.__name__, name=name + ) + elif isinstance(value, str) and is_coercible_from_string(annotation): + try: + resolved[name] = coerce_string_to_annotation(value=value, annotation=annotation) + except (ValueError, TypeError) as e: + raise ValueError(f"Parameter '{name}' of '{cls.__name__}': {e}") from e + else: + resolved[name] = value + + return resolved diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 6b41ef93af..a0a6331b0d 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -14,20 +14,17 @@ ConverterPreviewRequest, CreateConverterRequest, ) -from pyrit.backend.services.converter_service import ConverterService, _is_llm_based, get_converter_service +from pyrit.backend.services.converter_service import ( + ConverterService, + _serialize_type, + get_converter_service, +) from pyrit.models import ComponentIdentifier from pyrit.prompt_converter import ( Base64Converter, CaesarConverter, - LLMGenericTextConverter, - NoiseConverter, - PersuasionConverter, RepeatTokenConverter, SuffixAppendConverter, - TenseConverter, - ToneConverter, - TranslationConverter, - VariationConverter, ) from pyrit.prompt_converter.prompt_converter import get_converter_modalities from pyrit.registry.object_registries import ConverterRegistry @@ -35,7 +32,7 @@ @pytest.fixture(autouse=True) def reset_registry(): - """Reset the ConverterRegistry singleton before each test.""" + """Reset the converter registry before each test.""" ConverterRegistry.reset_instance() yield ConverterRegistry.reset_instance() @@ -104,6 +101,60 @@ async def test_list_converter_catalog_includes_supported_types(self) -> None: assert "text" in base64_entry.supported_input_types assert "text" in base64_entry.supported_output_types + async def test_catalog_includes_all_constructible_converters(self) -> None: + """The catalog surfaces every constructible converter, including base/helper classes. + + Whether to display a given converter is left to the caller (e.g. the frontend), + so the service no longer hides anything. + """ + service = ConverterService() + + result = await service.list_converter_catalog_async() + + converter_types = [item.converter_type for item in result.items] + assert "Base64Converter" in converter_types + assert "SelectiveTextConverter" in converter_types + + async def test_catalog_serializes_parameter_type(self) -> None: + """Catalog renders the raw annotation into a human-readable type_name.""" + service = ConverterService() + + result = await service.list_converter_catalog_async() + + caesar_entry = next(item for item in result.items if item.converter_type == "CaesarConverter") + caesar_param = next(p for p in caesar_entry.parameters if p.name == "caesar_offset") + assert caesar_param.type_name == "int" + + async def test_catalog_excludes_non_coercible_params(self) -> None: + """Catalog only surfaces params that can be set from a string (e.g. not the LLM target).""" + service = ConverterService() + + result = await service.list_converter_catalog_async() + + persuasion_entry = next(item for item in result.items if item.converter_type == "PersuasionConverter") + assert persuasion_entry.is_llm_based is True + assert all("Target" not in p.type_name for p in persuasion_entry.parameters) + + +class TestSerializeType: + """Tests for the _serialize_type presentation helper.""" + + def test_empty_annotation(self) -> None: + import inspect + + assert _serialize_type(inspect.Parameter.empty) == "Any" + + def test_plain_type(self) -> None: + assert _serialize_type(int) == "int" + + def test_optional_pep604(self) -> None: + assert _serialize_type(str | None) == "Optional[str]" + + def test_literal(self) -> None: + from typing import Literal + + assert _serialize_type(Literal["a", "b"]) == "Literal['a', 'b']" + class TestGetConverter: """Tests for ConverterService.get_converter method.""" @@ -607,25 +658,3 @@ def test_base64_converter_default_params(self) -> None: # Verify type info is populated from identifier assert isinstance(result.supported_input_types, list) assert isinstance(result.supported_output_types, list) - - -class TestIsLlmBased: - """Tests for the _is_llm_based introspection helper""" - - def test_detects_llm_text_converter(self) -> None: - # Test that _is_llm_based correctly identifies converters that use LLMS as LLM-based. - for cls in ( - LLMGenericTextConverter, - NoiseConverter, - PersuasionConverter, - ToneConverter, - TenseConverter, - TranslationConverter, - VariationConverter, - ): - assert _is_llm_based(cls) is True, f"{cls.__name__} should be detected as LLM-based" - - def test_does_not_flag_non_target_converters(self) -> None: - # Test that _is_llm_based does not incorrectly flag non-LLM converters. - assert _is_llm_based(Base64Converter) is False - assert _is_llm_based(CaesarConverter) is False diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 0a7463d1f0..15116a0ac9 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -110,6 +110,7 @@ def mock_all_registries(mock_memory): mock_sr = MagicMock() mock_sr.get_class.return_value = mock_scenario_class + mock_sr.create_instance.return_value = mock_scenario_instance mock_tr = MagicMock() mock_tr.get_instance_by_name.return_value = MagicMock() @@ -452,23 +453,25 @@ async def test_start_run_runs_initializers(self, mock_all_registries) -> None: assert mock_init_instance.initialize_async.await_count == 2 async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: - """Test that scenario_result_id is passed to the scenario constructor for resumption.""" + """Test that scenario_result_id is passed to the registry constructor for resumption.""" service = ScenarioRunService() - mock_scenario_class = mock_all_registries["scenario_class"] + mock_sr = mock_all_registries["scenario_registry"] response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) assert response.status == ScenarioRunStatus.IN_PROGRESS - mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") + mock_sr.create_instance.assert_called_once_with( + "foundry.red_team_agent", scenario_result_id="existing-result-uuid" + ) async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: - """Test that scenario_result_id is not passed to constructor when not provided.""" + """Test that scenario_result_id is not passed to the registry constructor when not provided.""" service = ScenarioRunService() - mock_scenario_class = mock_all_registries["scenario_class"] + mock_sr = mock_all_registries["scenario_registry"] await service.start_run_async(request=_make_request()) - mock_scenario_class.assert_called_once_with() + mock_sr.create_instance.assert_called_once_with("foundry.red_team_agent") class TestScenarioRunServiceGetRun: diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index 7fa2de4599..37bba2585e 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -1,9 +1,61 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.models import ComponentIdentifier, PromptDataType -from pyrit.prompt_converter import ConverterResult, PromptConverter -from pyrit.registry.object_registries.converter_registry import ConverterRegistry +""" +Tests for the merged ``ConverterRegistry`` (buildable catalog + instance container) +and its introspection helpers. +""" + +from typing import Literal + +import pytest + +from pyrit.models import ComponentIdentifier, Message, MessagePiece, PromptDataType +from pyrit.prompt_converter import ( + Base64Converter, + CaesarConverter, + ConverterResult, + LLMGenericTextConverter, + NoiseConverter, + PersuasionConverter, + PromptConverter, + TenseConverter, + ToneConverter, + TranslationConverter, + VariationConverter, +) +from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration +from pyrit.registry.object_registries import ( + ConverterMetadata, + ConverterRegistry, + TargetRegistry, +) +from pyrit.registry.object_registries.converter_registry import ( + _extract_parameters, + _requires_llm_target, +) + + +class MockPromptTarget(PromptTarget): + """Minimal PromptTarget (with LLM-converter capabilities) for resolution tests.""" + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ) + ) + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + pass class MockTextConverter(PromptConverter): @@ -15,10 +67,6 @@ class MockTextConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "text". - Returns: ConverterResult: The unchanged prompt. """ @@ -34,10 +82,6 @@ class MockImageConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "image_path") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "image_path". - Returns: ConverterResult: The unchanged prompt. """ @@ -53,297 +97,359 @@ class MockMultiModalConverter(PromptConverter): async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: """Convert prompt (no-op for testing). - Args: - prompt (str): The prompt to convert. - input_type (PromptDataType): The input type. Defaults to "text". - Returns: ConverterResult: The unchanged prompt. """ return ConverterResult(output_text=prompt, output_type="text") +@pytest.fixture +def registry(): + """Provide a fresh ``ConverterRegistry`` singleton, reset around each test.""" + ConverterRegistry.reset_instance() + instance = ConverterRegistry.get_registry_singleton() + yield instance + ConverterRegistry.reset_instance() + + +# --------------------------------------------------------------------------- +# Instance container (the registry's primary surface) +# --------------------------------------------------------------------------- + + class TestConverterRegistrySingleton: """Tests for the singleton pattern in ConverterRegistry.""" def setup_method(self): - """Reset the singleton before each test.""" ConverterRegistry.reset_instance() def teardown_method(self): - """Reset the singleton after each test.""" ConverterRegistry.reset_instance() def test_get_registry_singleton_returns_same_instance(self): - """Test that get_registry_singleton returns the same singleton each time.""" - instance1 = ConverterRegistry.get_registry_singleton() - instance2 = ConverterRegistry.get_registry_singleton() - - assert instance1 is instance2 + assert ConverterRegistry.get_registry_singleton() is ConverterRegistry.get_registry_singleton() def test_get_registry_singleton_returns_converter_registry_type(self): - """Test that get_registry_singleton returns a ConverterRegistry instance.""" - instance = ConverterRegistry.get_registry_singleton() - assert isinstance(instance, ConverterRegistry) + assert isinstance(ConverterRegistry.get_registry_singleton(), ConverterRegistry) def test_reset_instance_clears_singleton(self): - """Test that reset_instance clears the singleton.""" instance1 = ConverterRegistry.get_registry_singleton() ConverterRegistry.reset_instance() - instance2 = ConverterRegistry.get_registry_singleton() - - assert instance1 is not instance2 + assert ConverterRegistry.get_registry_singleton() is not instance1 class TestConverterRegistryRegisterInstance: """Tests for register_instance functionality in ConverterRegistry.""" - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() - - def test_register_instance_with_custom_name(self): - """Test registering a converter with a custom name.""" + def test_register_instance_with_custom_name(self, registry: ConverterRegistry): converter = MockTextConverter() - self.registry.register_instance(converter, name="custom_converter") + registry.register_instance(converter, name="custom_converter") - assert "custom_converter" in self.registry - assert self.registry.get("custom_converter") is converter + assert "custom_converter" in registry + assert registry.get_instance_by_name("custom_converter") is converter - def test_register_instance_generates_name_from_class(self): - """Test that register_instance generates a name from class name when not provided.""" + def test_register_instance_generates_name_from_class(self, registry: ConverterRegistry): converter = MockTextConverter() - self.registry.register_instance(converter) + registry.register_instance(converter) - # Name should be derived from class name with hash suffix - names = self.registry.get_names() + names = registry.get_names() assert len(names) == 1 assert names[0].startswith("MockTextConverter::") - def test_register_instance_multiple_converters_unique_names(self): - """Test registering multiple converters generates unique names.""" + def test_register_instance_multiple_converters_unique_names(self, registry: ConverterRegistry): + registry.register_instance(MockTextConverter()) + registry.register_instance(MockImageConverter()) + + assert len(registry) == 2 + + def test_register_instance_duplicate_name_overwrites(self, registry: ConverterRegistry): converter1 = MockTextConverter() converter2 = MockImageConverter() - self.registry.register_instance(converter1) - self.registry.register_instance(converter2) + registry.register_instance(converter1, name="shared_name") + registry.register_instance(converter2, name="shared_name") - assert len(self.registry) == 2 + assert len(registry) == 1 + assert registry.get_instance_by_name("shared_name") is converter2 - def test_register_instance_same_converter_type_different_names(self): - """Test that same converter class can be registered with different names.""" - converter1 = MockTextConverter() - converter2 = MockTextConverter() - self.registry.register_instance(converter1, name="converter_1") - self.registry.register_instance(converter2, name="converter_2") +class TestConverterRegistryGetInstanceByName: + """Tests for get_instance_by_name functionality in ConverterRegistry.""" - assert len(self.registry) == 2 + def test_get_instance_by_name_returns_converter(self, registry: ConverterRegistry): + converter = MockTextConverter() + registry.register_instance(converter, name="test_converter") + assert registry.get_instance_by_name("test_converter") is converter - def test_register_instance_duplicate_name_overwrites(self): - """Test that registering with a duplicate name silently overwrites the previous instance.""" - converter1 = MockTextConverter() - converter2 = MockImageConverter() + def test_get_instance_by_name_nonexistent_returns_none(self, registry: ConverterRegistry): + assert registry.get_instance_by_name("nonexistent") is None - self.registry.register_instance(converter1, name="shared_name") - self.registry.register_instance(converter2, name="shared_name") - assert len(self.registry) == 1 - assert self.registry.get("shared_name") is converter2 +class TestConverterRegistryInstanceMetadata: + """Tests for instance-level metadata (list_metadata is the container surface).""" + def test_instance_metadata_is_component_identifier(self, registry: ConverterRegistry): + converter = MockTextConverter() + registry.register_instance(converter, name="text_converter") -class TestConverterRegistryGetInstanceByName: - """Tests for get_instance_by_name functionality in ConverterRegistry.""" + metadata = registry.list_metadata() + assert len(metadata) == 1 + assert isinstance(metadata[0], ComponentIdentifier) + assert metadata[0] == converter.get_identifier() - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - self.converter = MockTextConverter() - self.registry.register_instance(self.converter, name="test_converter") + def test_instance_metadata_filter_by_class_name(self, registry: ConverterRegistry): + registry.register_instance(MockTextConverter(), name="t1") + registry.register_instance(MockTextConverter(), name="t2") + registry.register_instance(MockImageConverter(), name="i1") - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + metadata = registry.list_metadata(include_filters={"class_name": "MockTextConverter"}) + assert len(metadata) == 2 + assert all(m.class_name == "MockTextConverter" for m in metadata) - def test_get_instance_by_name_returns_converter(self): - """Test getting a registered converter by name.""" - result = self.registry.get_instance_by_name("test_converter") - assert result is self.converter - def test_get_instance_by_name_nonexistent_returns_none(self): - """Test that getting a non-existent converter returns None.""" - result = self.registry.get_instance_by_name("nonexistent") - assert result is None +class TestConverterRegistryContainerProtocol: + """Tests for the instance-primary protocol surface.""" + def test_contains_and_len_and_iter(self, registry: ConverterRegistry): + registry.register_instance(MockTextConverter(), name="test_converter") + assert "test_converter" in registry + assert "unknown_converter" not in registry + assert len(registry) == 1 + assert "test_converter" in list(registry) -class TestConverterRegistryBuildMetadata: - """Tests for _build_metadata functionality in ConverterRegistry.""" + def test_get_names_returns_sorted_list(self, registry: ConverterRegistry): + registry.register_instance(MockImageConverter(), name="zeta_converter") + registry.register_instance(MockImageConverter(), name="alpha_converter") + assert registry.get_names() == ["alpha_converter", "zeta_converter"] - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() + def test_get_all_instances_returns_all(self, registry: ConverterRegistry): + text = MockTextConverter() + image = MockImageConverter() + registry.register_instance(text, name="text_converter") + registry.register_instance(image, name="image_converter") - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + entry_map = {e.name: e for e in registry.get_all_instances()} + assert entry_map["text_converter"].instance is text + assert entry_map["image_converter"].instance is image - def test_build_metadata_includes_class_name(self): - """Test that metadata includes the converter class name.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") - metadata = self.registry.list_metadata() - assert len(metadata) == 1 - assert metadata[0].class_name == "MockTextConverter" +# --------------------------------------------------------------------------- +# Buildable class catalog (discovery + introspection + build) +# --------------------------------------------------------------------------- - def test_build_metadata_includes_supported_input_types(self): - """Test that metadata includes supported_input_types in params.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("text",) +class TestDiscovery: + """Tests for converter class discovery.""" - def test_build_metadata_includes_supported_output_types(self): - """Test that metadata includes supported_output_types in params.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") + def test_discovers_known_converters(self, registry: ConverterRegistry): + names = registry.get_class_names() + assert "Base64Converter" in names + assert "CaesarConverter" in names - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_output_types"] == ("text",) + def test_discovers_non_catalog_converters(self, registry: ConverterRegistry): + # SelectiveTextConverter is hidden from the user-facing catalog (a frontend + # concern) but must remain discoverable/buildable so agents can use it. + assert "SelectiveTextConverter" in registry.get_class_names() - def test_build_metadata_is_component_identifier(self): - """Test that metadata is the converter's ComponentIdentifier.""" - converter = MockTextConverter() - self.registry.register_instance(converter, name="text_converter") + def test_does_not_register_base_class(self, registry: ConverterRegistry): + assert "PromptConverter" not in registry.get_class_names() - metadata = self.registry.list_metadata() - assert isinstance(metadata[0], ComponentIdentifier) - assert metadata[0] == converter.get_identifier() + def test_keyed_by_exact_class_name(self, registry: ConverterRegistry): + names = registry.get_class_names() + assert "Base64Converter" in names + assert "base64_converter" not in names - def test_build_metadata_different_modalities(self): - """Test that metadata reflects converter-specific modalities.""" - converter = MockImageConverter() - self.registry.register_instance(converter, name="image_converter") - metadata = self.registry.list_metadata() - assert metadata[0].params["supported_input_types"] == ("image_path",) - assert metadata[0].params["supported_output_types"] == ("text",) - assert metadata[0].class_name == "MockImageConverter" +class TestGetClass: + """Tests for get_class (the inherited class-catalog accessor).""" + def test_returns_class(self, registry: ConverterRegistry): + assert registry.get_class("Base64Converter") is Base64Converter -class TestConverterRegistryListMetadataFiltering: - """Tests for list_metadata filtering in ConverterRegistry.""" + def test_unknown_type_raises(self, registry: ConverterRegistry): + with pytest.raises(KeyError, match="not found"): + registry.get_class("NotARealConverter") - def setup_method(self): - """Reset and get a fresh registry with multiple converters.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() + def test_is_subclass_relationship(self, registry: ConverterRegistry): + assert issubclass(registry.get_class("Base64Converter"), PromptConverter) - self.text_converter1 = MockTextConverter() - self.text_converter2 = MockTextConverter() - self.image_converter = MockImageConverter() - self.multi_modal_converter = MockMultiModalConverter() - self.registry.register_instance(self.text_converter1, name="text_converter_1") - self.registry.register_instance(self.text_converter2, name="text_converter_2") - self.registry.register_instance(self.image_converter, name="image_converter") - self.registry.register_instance(self.multi_modal_converter, name="multi_modal_converter") +class TestCreateInstance: + """Tests for create_instance (build via the shared resolver).""" - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() + def test_creates_instance(self, registry: ConverterRegistry): + assert isinstance(registry.create_instance("Base64Converter"), Base64Converter) - def test_list_metadata_no_filter_returns_all(self): - """Test that list_metadata without filters returns all items.""" - metadata = self.registry.list_metadata() - assert len(metadata) == 4 + def test_coerces_string_params(self, registry: ConverterRegistry): + converter = registry.create_instance("CaesarConverter", caesar_offset="13") + assert isinstance(converter, CaesarConverter) + assert converter.get_identifier().params.get("caesar_offset") == 13 - def test_list_metadata_filter_by_class_name(self): - """Test filtering metadata by class_name.""" - metadata = self.registry.list_metadata(include_filters={"class_name": "MockTextConverter"}) - assert len(metadata) == 2 - assert all(m.class_name == "MockTextConverter" for m in metadata) + def test_unknown_type_raises(self, registry: ConverterRegistry): + with pytest.raises(KeyError, match="not found"): + registry.create_instance("NotARealConverter") - def test_list_metadata_filter_by_supported_input_type(self): - """Test filtering metadata by supported_input_types (containment check).""" - # "text" is in supported_input_types for MockTextConverter and MockMultiModalConverter - metadata = self.registry.list_metadata(include_filters={"supported_input_types": "text"}) - assert len(metadata) == 3 # 2 text converters + 1 multi-modal - class_names = {m.class_name for m in metadata} - assert "MockTextConverter" in class_names - assert "MockMultiModalConverter" in class_names - - def test_list_metadata_exclude_by_class_name(self): - """Test excluding metadata by class_name.""" - metadata = self.registry.list_metadata(exclude_filters={"class_name": "MockTextConverter"}) - assert len(metadata) == 2 - assert all(m.class_name != "MockTextConverter" for m in metadata) - - def test_list_metadata_combined_include_and_exclude(self): - """Test combined include and exclude filters.""" - # Include converters that accept text, exclude MockMultiModalConverter - metadata = self.registry.list_metadata( - include_filters={"supported_input_types": "text"}, - exclude_filters={"class_name": "MockMultiModalConverter"}, + def test_unknown_param_raises(self, registry: ConverterRegistry): + with pytest.raises(ValueError, match="Unknown parameter"): + registry.create_instance("Base64Converter", not_a_param="x") + + def test_build_does_not_register_instance(self, registry: ConverterRegistry): + registry.create_instance("Base64Converter") + assert len(registry) == 0 + + def test_honors_registered_default_kwargs(self, registry: ConverterRegistry): + registry.register(CaesarConverter, name="CaesarDefault", default_kwargs={"caesar_offset": 5}) + converter = registry.create_instance("CaesarDefault") + assert converter.get_identifier().params.get("caesar_offset") == 5 + + def test_uses_registered_factory(self, registry: ConverterRegistry): + sentinel = Base64Converter() + registry.register(Base64Converter, name="B64Factory", factory=lambda **kwargs: sentinel) + assert registry.create_instance("B64Factory") is sentinel + + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateLLMConverter: + """Tests that LLM converters are buildable by resolving a target by name.""" + + def test_build_llm_converter_resolves_target_by_name(self, registry: ConverterRegistry): + target = MockPromptTarget() + TargetRegistry.reset_instance() + TargetRegistry.get_registry_singleton().register_instance(target, name="my_target") + try: + converter = registry.create_instance("TenseConverter", converter_target="my_target", tense="past") + assert isinstance(converter, TenseConverter) + assert converter._converter_target is target + finally: + TargetRegistry.reset_instance() + + def test_build_llm_converter_unknown_target_raises(self, registry: ConverterRegistry): + TargetRegistry.reset_instance() + try: + with pytest.raises(ValueError, match="not found"): + registry.create_instance("TenseConverter", converter_target="missing", tense="past") + finally: + TargetRegistry.reset_instance() + + +class TestClassMetadata: + """Tests for converter class-catalog metadata building.""" + + def _metadata_for(self, registry: ConverterRegistry, name: str) -> ConverterMetadata: + return next(m for m in registry.list_class_metadata() if m.class_name == name) + + def test_metadata_includes_supported_types(self, registry: ConverterRegistry): + meta = self._metadata_for(registry, "Base64Converter") + assert "text" in meta.supported_input_types + assert "text" in meta.supported_output_types + + def test_metadata_has_no_catalog_visible_field(self, registry: ConverterRegistry): + # catalog_visible is a presentation concern owned by the backend/frontend. + assert not hasattr(self._metadata_for(registry, "Base64Converter"), "catalog_visible") + + def test_is_llm_based_flag(self, registry: ConverterRegistry): + llm_based = ( + LLMGenericTextConverter, + NoiseConverter, + PersuasionConverter, + ToneConverter, + TenseConverter, + TranslationConverter, + VariationConverter, ) - assert len(metadata) == 2 - assert all(m.class_name == "MockTextConverter" for m in metadata) + for cls in llm_based: + meta = self._metadata_for(registry, cls.__name__) + assert meta.is_llm_based is True, f"{cls.__name__} should be LLM-based" + assert self._metadata_for(registry, "Base64Converter").is_llm_based is False + assert self._metadata_for(registry, "CaesarConverter").is_llm_based is False + def test_parameters_extracted(self, registry: ConverterRegistry): + meta = self._metadata_for(registry, "CaesarConverter") + caesar_param = next(p for p in meta.parameters if p.name == "caesar_offset") + assert caesar_param.required is True + assert caesar_param.annotation is int + assert caesar_param.coercible_from_string is True -class TestConverterRegistryInheritedMethods: - """Tests for inherited methods from RetrievableInstanceRegistry.""" + def test_surfaces_non_coercible_params(self, registry: ConverterRegistry): + # An LLM-based converter exposes its target parameter for dynamic + # construction even though it cannot be coerced from a string. + meta = self._metadata_for(registry, "PersuasionConverter") + non_coercible = [p for p in meta.parameters if not p.coercible_from_string] + assert non_coercible, "expected at least one non-coercible parameter (the LLM target)" - def setup_method(self): - """Reset and get a fresh registry.""" - ConverterRegistry.reset_instance() - self.registry = ConverterRegistry.get_registry_singleton() - self.converter = MockTextConverter() - self.registry.register_instance(self.converter, name="test_converter") - def teardown_method(self): - """Reset the singleton after each test.""" - ConverterRegistry.reset_instance() +# --------------------------------------------------------------------------- +# Introspection helpers +# --------------------------------------------------------------------------- + + +class _UnionTargetConverter: + """Helper with a PEP 604 unioned target parameter for introspection tests.""" + + def __init__(self, *, target: PromptTarget | None = None, offset: int | None = None) -> None: + self.target = target + self.offset = offset + + +class _OptionalLiteralConverter: + """Helper with an optional Literal parameter for choices extraction tests.""" + + def __init__(self, *, fmt: Literal["A", "B"] | None = None) -> None: + self.fmt = fmt + + +class TestExtractParameters: + """Tests for the converter-parameter introspection helper.""" + + def test_exposes_raw_annotation(self) -> None: + offset_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "offset") + assert offset_param.annotation == (int | None) + assert offset_param.coercible_from_string is True + + def test_includes_non_coercible(self) -> None: + target_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "target") + assert target_param.coercible_from_string is False + + def test_optional_literal_choices(self) -> None: + fmt_param = next(p for p in _extract_parameters(_OptionalLiteralConverter) if p.name == "fmt") + assert fmt_param.choices == ("A", "B") + + def test_sets_requires_llm(self) -> None: + params = _extract_parameters(_UnionTargetConverter) + target_param = next(p for p in params if p.name == "target") + offset_param = next(p for p in params if p.name == "offset") + assert target_param.requires_llm is True + assert offset_param.requires_llm is False + + +class TestRequiresLlmTarget: + """Tests for the _requires_llm_target helper.""" + + def test_plain_target(self) -> None: + assert _requires_llm_target(PromptTarget) is True + + def test_optional_target(self) -> None: + assert _requires_llm_target(PromptTarget | None) is True + + def test_non_target(self) -> None: + assert _requires_llm_target(int) is False + assert _requires_llm_target(str | None) is False + + +class TestNoBackendDependency: + """The registry must be reusable without depending on pyrit.backend.""" + + def test_module_has_no_backend_dependency(self) -> None: + import ast + import inspect + + import pyrit.registry.object_registries.converter_registry as module - def test_contains_registered_name(self): - """Test __contains__ for registered name.""" - assert "test_converter" in self.registry - - def test_contains_unregistered_name(self): - """Test __contains__ for unregistered name.""" - assert "unknown_converter" not in self.registry - - def test_len_returns_count(self): - """Test __len__ returns correct count.""" - assert len(self.registry) == 1 - - def test_iter_yields_names(self): - """Test __iter__ yields registered names.""" - names = list(self.registry) - assert "test_converter" in names - - def test_get_names_returns_sorted_list(self): - """Test get_names returns sorted list of names.""" - self.registry.register_instance(MockImageConverter(), name="alpha_converter") - self.registry.register_instance(MockImageConverter(), name="zeta_converter") - - names = self.registry.get_names() - assert names == ["alpha_converter", "test_converter", "zeta_converter"] - - def test_get_all_instances_returns_all(self): - """Test get_all_instances returns list of all registered entries.""" - image_converter = MockImageConverter() - self.registry.register_instance(image_converter, name="image_converter") - - all_entries = self.registry.get_all_instances() - assert len(all_entries) == 2 - entry_map = {e.name: e for e in all_entries} - assert entry_map["test_converter"].instance is self.converter - assert entry_map["image_converter"].instance is image_converter + tree = ast.parse(inspect.getsource(module)) + imported_modules: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + assert not any(name.startswith("pyrit.backend") for name in imported_modules) diff --git a/tests/unit/registry/test_resolution.py b/tests/unit/registry/test_resolution.py new file mode 100644 index 0000000000..8a7889bd78 --- /dev/null +++ b/tests/unit/registry/test_resolution.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the shared registry constructor-argument resolution primitive. +""" + +from typing import Literal + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import PromptTarget +from pyrit.registry.object_registries import TargetRegistry +from pyrit.registry.resolution import ( + coerce_string_to_annotation, + get_resolvable_registry_getter, + get_union_non_none_args, + is_coercible_from_string, + is_registry_reference, + resolve_constructor_args, +) + + +class MockPromptTarget(PromptTarget): + """Minimal PromptTarget for registry-resolution tests.""" + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return [MessagePiece(role="assistant", original_value="mock response").to_message()] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + pass + + +class _NeedsTarget: + """Helper whose constructor takes a registry-reference target plus simple params.""" + + def __init__(self, *, converter_target: PromptTarget, offset: int = 0, label: str = "x") -> None: + self.converter_target = converter_target + self.offset = offset + self.label = label + + +class _SimpleOnly: + """Helper whose constructor takes only simple/coercible params.""" + + def __init__( + self, *, count: int = 1, ratio: float = 0.5, flag: bool = False, mode: Literal["a", "b"] = "a" + ) -> None: + self.count = count + self.ratio = ratio + self.flag = flag + self.mode = mode + + +class _AcceptsKwargs: + """Helper whose constructor accepts arbitrary keyword arguments.""" + + def __init__(self, *, name: str = "n", **kwargs: object) -> None: + self.name = name + self.kwargs = kwargs + + +@pytest.fixture +def target_registry(): + """Provide a fresh TargetRegistry singleton with one registered target.""" + TargetRegistry.reset_instance() + registry = TargetRegistry.get_registry_singleton() + registry.register_instance(MockPromptTarget(), name="my_target") + yield registry + TargetRegistry.reset_instance() + + +@pytest.fixture +def empty_target_registry(): + """Provide a fresh, empty TargetRegistry singleton.""" + TargetRegistry.reset_instance() + registry = TargetRegistry.get_registry_singleton() + yield registry + TargetRegistry.reset_instance() + + +class TestTypeHelpers: + """Tests for the type-introspection helpers.""" + + def test_get_union_non_none_args_pep604(self) -> None: + assert get_union_non_none_args(int | None) == [int] + + def test_get_union_non_none_args_not_a_union(self) -> None: + assert get_union_non_none_args(int) is None + + def test_is_coercible_from_string(self) -> None: + assert is_coercible_from_string(str) is True + assert is_coercible_from_string(int | None) is True + assert is_coercible_from_string(Literal["a", "b"]) is True + assert is_coercible_from_string(PromptTarget) is False + + def test_is_registry_reference(self) -> None: + assert is_registry_reference(PromptTarget) is True + assert is_registry_reference(PromptTarget | None) is True + assert is_registry_reference(int) is False + + def test_get_resolvable_registry_getter_returns_target_registry(self) -> None: + getter = get_resolvable_registry_getter(PromptTarget) + assert getter is not None + assert isinstance(getter(), TargetRegistry) + + def test_get_resolvable_registry_getter_none_for_simple(self) -> None: + assert get_resolvable_registry_getter(int) is None + + +class TestCoerceStringToAnnotation: + """Tests for scalar string coercion.""" + + def test_int(self) -> None: + assert coerce_string_to_annotation(value="42", annotation=int) == 42 + + def test_float(self) -> None: + assert coerce_string_to_annotation(value="0.25", annotation=float) == 0.25 + + def test_bool_true(self) -> None: + assert coerce_string_to_annotation(value="yes", annotation=bool) is True + + def test_bool_false(self) -> None: + assert coerce_string_to_annotation(value="0", annotation=bool) is False + + def test_bool_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="boolean"): + coerce_string_to_annotation(value="maybe", annotation=bool) + + def test_optional_unwrapped(self) -> None: + assert coerce_string_to_annotation(value="7", annotation=int | None) == 7 + + def test_str_passthrough(self) -> None: + assert coerce_string_to_annotation(value="hello", annotation=str) == "hello" + + def test_literal_valid(self) -> None: + assert coerce_string_to_annotation(value="b", annotation=Literal["a", "b"]) == "b" + + def test_literal_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="one of"): + coerce_string_to_annotation(value="c", annotation=Literal["a", "b"]) + + def test_literal_coerces_to_member_type(self) -> None: + result = coerce_string_to_annotation(value="2", annotation=Literal[1, 2]) + assert result == 2 + assert isinstance(result, int) + + +@pytest.mark.usefixtures("patch_central_database") +class TestResolveConstructorArgs: + """Tests for the end-to-end resolve_constructor_args.""" + + def test_coerces_simple_params(self) -> None: + resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "3", "ratio": "0.75", "flag": "true"}) + assert resolved == {"count": 3, "ratio": 0.75, "flag": True} + + def test_literal_passthrough(self) -> None: + resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "b"}) + assert resolved == {"mode": "b"} + + def test_literal_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="mode"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "z"}) + + def test_unknown_param_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown parameter 'nope'"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) + + def test_unknown_param_lists_valid_params(self) -> None: + with pytest.raises(ValueError, match="count"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) + + def test_var_kwargs_accepts_unknown(self) -> None: + resolved = resolve_constructor_args(cls=_AcceptsKwargs, raw_args={"anything": "value"}) + assert resolved == {"anything": "value"} + + def test_invalid_coercion_raises(self) -> None: + with pytest.raises(ValueError, match="count"): + resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "not-an-int"}) + + def test_resolves_registry_reference_by_name(self, target_registry: TargetRegistry) -> None: + resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "my_target", "offset": "5"}) + assert resolved["converter_target"] is target_registry.get_instance_by_name("my_target") + assert resolved["offset"] == 5 + + def test_registry_reference_instance_passthrough(self, target_registry: TargetRegistry) -> None: + instance = MockPromptTarget() + resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": instance}) + assert resolved["converter_target"] is instance + + def test_unknown_registry_reference_raises_with_names(self, target_registry: TargetRegistry) -> None: + with pytest.raises(ValueError, match="my_target"): + resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + + def test_unknown_registry_reference_empty_registry_hint(self, empty_target_registry: TargetRegistry) -> None: + with pytest.raises(ValueError, match="is empty"): + resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + + +def test_module_has_no_backend_dependency() -> None: + # The resolution primitive must be reusable without depending on pyrit.backend. + import ast + import inspect + + import pyrit.registry.resolution as module + + tree = ast.parse(inspect.getsource(module)) + imported_modules: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + assert not any(name.startswith("pyrit.backend") for name in imported_modules)