From 59672f1acb142f9c7f1fcc0d75151dc6193fbed2 Mon Sep 17 00:00:00 2001 From: jbolor21 <86250273+jbolor21@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:24:02 -0700 Subject: [PATCH 1/3] adding inital POC draft for scorers in GUI WIP --- .../src/components/Chat/ChatWindow.test.tsx | 36 ++ frontend/src/components/Chat/ChatWindow.tsx | 60 ++- .../src/components/Chat/ConversationPanel.tsx | 39 +- frontend/src/components/Chat/MessageList.tsx | 56 ++- .../src/components/Chat/ScoreDialog.test.tsx | 255 +++++++++++ frontend/src/components/Chat/ScoreDialog.tsx | 395 ++++++++++++++++++ frontend/src/services/api.ts | 36 ++ frontend/src/types/index.ts | 41 ++ frontend/src/utils/messageMapper.ts | 5 + pyrit/backend/main.py | 3 + pyrit/backend/models/scoring.py | 85 ++++ pyrit/backend/routes/scoring.py | 152 +++++++ pyrit/backend/services/scoring_service.py | 273 ++++++++++++ tests/unit/backend/test_scoring_service.py | 356 ++++++++++++++++ 14 files changed, 1788 insertions(+), 4 deletions(-) create mode 100644 frontend/src/components/Chat/ScoreDialog.test.tsx create mode 100644 frontend/src/components/Chat/ScoreDialog.tsx create mode 100644 pyrit/backend/models/scoring.py create mode 100644 pyrit/backend/routes/scoring.py create mode 100644 pyrit/backend/services/scoring_service.py create mode 100644 tests/unit/backend/test_scoring_service.py diff --git a/frontend/src/components/Chat/ChatWindow.test.tsx b/frontend/src/components/Chat/ChatWindow.test.tsx index 357b15e832..18ccba31a4 100644 --- a/frontend/src/components/Chat/ChatWindow.test.tsx +++ b/frontend/src/components/Chat/ChatWindow.test.tsx @@ -32,6 +32,8 @@ jest.mock("../../services/api", () => ({ getConversations: jest.fn(), createConversation: jest.fn(), changeMainConversation: jest.fn(), + scoreConversation: jest.fn(), + scoreMessagePiece: jest.fn(), }, convertersApi: { listConverterCatalog: jest.fn(), @@ -40,6 +42,9 @@ jest.mock("../../services/api", () => ({ createConverter: jest.fn(), previewConversion: jest.fn(), }, + scorersApi: { + listScorers: jest.fn().mockResolvedValue({ items: [] }), + }, labelsApi: { getLabels: jest.fn().mockImplementation(() => new Promise(() => {})), }, @@ -2159,6 +2164,37 @@ describe("ChatWindow Integration", () => { expect(toggleBtn).toBe(screen.getByTestId("toggle-panel-btn")); }); + it("ribbon Score button is disabled until a conversation is active", () => { + render( + + + + ); + expect(screen.getByTestId("score-conversation-btn")).toBeDisabled(); + }); + + it("ribbon Score button opens the score dialog for the active conversation", async () => { + render( + + + + ); + + const scoreBtn = screen.getByTestId("score-conversation-btn"); + expect(scoreBtn).toBeEnabled(); + await userEvent.click(scoreBtn); + + // ScoreDialog mounts and fetches scorers (mock resolves to empty list). + await waitFor(() => { + expect(screen.getByTestId("score-dialog-empty")).toBeInTheDocument(); + }); + }); + it("should toggle converter panel when convert button is clicked", async () => { render( diff --git a/frontend/src/components/Chat/ChatWindow.tsx b/frontend/src/components/Chat/ChatWindow.tsx index 230678e220..ad12765f15 100644 --- a/frontend/src/components/Chat/ChatWindow.tsx +++ b/frontend/src/components/Chat/ChatWindow.tsx @@ -4,12 +4,13 @@ import { Text, Tooltip, } from '@fluentui/react-components' -import { AddRegular, PanelRightRegular } from '@fluentui/react-icons' +import { AddRegular, PanelRightRegular, ClipboardTaskRegular } from '@fluentui/react-icons' import MessageList from './MessageList' import ChatInputArea from './ChatInputArea' import ConversationPanel from './ConversationPanel' import ConverterPanel from './ConverterPanel' import TargetBadge from './TargetBadge' +import ScoreDialog, { type ScoreTarget } from './ScoreDialog' import type { PieceConversion } from './converterTypes' import { PIECE_TYPE_TO_DATA_TYPE, basenameFromValue, buildMediaUrl, dataTypeToAttachmentKind, isPathDataType } from './converterTypes' import LabelsBar from '../Labels/LabelsBar' @@ -74,6 +75,7 @@ export default function ChatWindow({ const [attachmentData, setAttachmentData] = useState>({}) const [pieceConversions, setPieceConversions] = useState>({}) const [panelRefreshKey, setPanelRefreshKey] = useState(0) + const [scoreTarget, setScoreTarget] = useState(null) const inputBoxRef = useRef(null) const handleAttachmentsChange = useCallback((types: string[], data: Record) => { @@ -485,6 +487,28 @@ export default function ChatWindow({ } }, [attackResultId]) + // Open the score dialog for a specific assistant message piece. + const handleScoreMessage = useCallback((messageIndex: number) => { + if (!attackResultId || !activeConversationId) return + const msg = messages[messageIndex] + if (!msg?.pieceId) return + setScoreTarget({ + kind: 'piece', + attackResultId, + conversationId: activeConversationId, + pieceId: msg.pieceId, + }) + }, [attackResultId, activeConversationId, messages]) + + // After any score completes, refetch messages so the new score badges appear + // and bump the conversation panel refresh so its scoreboard / count stays current. + const handleScored = useCallback(() => { + if (attackResultId && activeConversationId) { + loadConversation(attackResultId, activeConversationId) + } + setPanelRefreshKey(k => k + 1) + }, [attackResultId, activeConversationId, loadConversation]) + const singleTurnLimitReached = activeTarget?.capabilities?.supports_multi_turn === false && messages.some(m => m.role === 'user') // Operator locking: if the loaded attack's operator differs from the current @@ -564,6 +588,32 @@ export default function ChatWindow({ )}
+ + +
) } diff --git a/frontend/src/components/Chat/ConversationPanel.tsx b/frontend/src/components/Chat/ConversationPanel.tsx index 267b0feaf1..76567edb5d 100644 --- a/frontend/src/components/Chat/ConversationPanel.tsx +++ b/frontend/src/components/Chat/ConversationPanel.tsx @@ -17,11 +17,13 @@ import { DismissRegular, StarRegular, StarFilled, + ClipboardTaskRegular, } from '@fluentui/react-icons' import { attacksApi } from '../../services/api' import { toApiError } from '../../services/errors' -import type { ConversationSummary } from '../../types' +import type { BackendScore, ConversationSummary } from '../../types' import { useConversationPanelStyles } from './ConversationPanel.styles' +import ScoreDialog, { type ScoreTarget } from './ScoreDialog' interface ConversationPanelProps { attackResultId: string | null @@ -34,6 +36,8 @@ interface ConversationPanelProps { lockedReason?: string /** Increment to trigger a conversation list refresh (e.g. after sending a message) */ refreshKey?: number + /** Called after a conversation is scored so the parent can refetch messages. */ + onConversationScored?: (conversationId: string, scores: BackendScore[]) => void } export default function ConversationPanel({ @@ -45,12 +49,14 @@ export default function ConversationPanel({ onClose, lockedReason, refreshKey, + onConversationScored, }: ConversationPanelProps) { const styles = useConversationPanelStyles() const [conversations, setConversations] = useState([]) const [mainConversationId, setMainConversationId] = useState(null) const [isLoading, setIsLoading] = useState(false) const [error, setError] = useState(null) + const [scoreTarget, setScoreTarget] = useState(null) const fetchConversations = useCallback(async () => { if (!attackResultId) { @@ -202,6 +208,25 @@ export default function ConversationPanel({ style={{ minWidth: 'auto', padding: '2px' }} /> + + + + + + + + ) +} + +export type { ScoreTarget, ScoreDialogProps } diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index 3c04828cb0..c4fa9fc23b 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -20,6 +20,10 @@ import type { CreateConversationRequest, CreateConversationResponse, ChangeMainConversationResponse, + ScorerListResponse, + ScoreConversationRequest, + ScoreMessageRequest, + ScoreResponse, } from '../types' const API_BASE_URL = import.meta.env.VITE_API_URL || '/api' @@ -277,6 +281,38 @@ export const attacksApi = { const response = await apiClient.get('/attacks/converter-options') return response.data }, + + scoreConversation: async ( + attackResultId: string, + conversationId: string, + request: ScoreConversationRequest + ): Promise => { + const response = await apiClient.post( + `/attacks/${encodeURIComponent(attackResultId)}/conversations/${encodeURIComponent(conversationId)}/scores`, + request + ) + return response.data + }, + + scoreMessagePiece: async ( + attackResultId: string, + conversationId: string, + pieceId: string, + request: ScoreMessageRequest + ): Promise => { + const response = await apiClient.post( + `/attacks/${encodeURIComponent(attackResultId)}/conversations/${encodeURIComponent(conversationId)}/pieces/${encodeURIComponent(pieceId)}/scores`, + request + ) + return response.data + }, +} + +export const scorersApi = { + listScorers: async (): Promise => { + const response = await apiClient.get('/scorers') + return response.data + }, } export const labelsApi = { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1c6dcc283e..743cdcc8cd 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -38,6 +38,14 @@ export interface Message { originalContent?: string /** Original media attachments before conversion (when different from converted). */ originalAttachments?: MessageAttachment[] + /** + * Backend piece ID of the first piece in this message. Preserved so the + * GUI can target a specific piece (e.g. for per-message scoring) without + * extending Message to carry every individual piece's id. + */ + pieceId?: string + /** Aggregated scores across all pieces in this message. */ + scores?: BackendScore[] } export interface MessageError { @@ -273,3 +281,36 @@ export interface ChangeMainConversationResponse { attack_result_id: string conversation_id: string } + +// --- Scoring --- + +export type ScorerScoreType = 'true_false' | 'float_scale' | 'unknown' + +export interface ScorerSummary { + scorer_registry_name: string + scorer_type: string + score_type: ScorerScoreType + description?: string | null + tags?: string[] +} + +export interface ScorerListResponse { + items: ScorerSummary[] +} + +export type ScoreConversationMode = 'last_message' | 'whole_conversation' + +export interface ScoreConversationRequest { + scorer_registry_name: string + mode?: ScoreConversationMode + objective?: string +} + +export interface ScoreMessageRequest { + scorer_registry_name: string + objective?: string +} + +export interface ScoreResponse { + scores: BackendScore[] +} diff --git a/frontend/src/utils/messageMapper.ts b/frontend/src/utils/messageMapper.ts index 703aca0b4a..c4deea8c86 100644 --- a/frontend/src/utils/messageMapper.ts +++ b/frontend/src/utils/messageMapper.ts @@ -244,6 +244,11 @@ export function backendMessageToFrontend(msg: BackendMessage): Message { reasoningSummaries: reasoningSummaries.length > 0 ? reasoningSummaries : undefined, originalContent: hasTextDiff ? originalContent : undefined, originalAttachments: hasMediaDiff ? originalAttachments : undefined, + pieceId: msg.pieces[0]?.piece_id, + scores: (() => { + const allScores = msg.pieces.flatMap((p) => p.scores ?? []) + return allScores.length > 0 ? allScores : undefined + })(), } } diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index c2c2f477cf..76ea7db8bf 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -27,6 +27,7 @@ labels, media, scenarios, + scoring, targets, version, ) @@ -123,6 +124,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(auth.router, prefix="/api", tags=["auth"]) app.include_router(media.router, prefix="/api", tags=["media"]) +app.include_router(scoring.scorers_router, prefix="/api", tags=["scorers"]) +app.include_router(scoring.attack_scoring_router, prefix="/api", tags=["scorers"]) app.include_router(version.router, tags=["version"]) diff --git a/pyrit/backend/models/scoring.py b/pyrit/backend/models/scoring.py new file mode 100644 index 0000000000..55f7cf0514 --- /dev/null +++ b/pyrit/backend/models/scoring.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scoring request/response models. + +DTOs for the on-demand scoring surface exposed under ``/api/scorers`` and +``/api/attacks/{id}/conversations/{cid}/scores``. Distinct from the planned +read-only scorer-introspection surface (eval metrics, etc.) — this file only +covers the inputs and outputs needed to *invoke* a registered scorer. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +from pyrit.backend.models.attacks import Score + +__all__ = [ + "ScorerSummary", + "ScorerListResponse", + "ScoreConversationMode", + "ScoreConversationRequest", + "ScoreMessageRequest", + "ScoreResponse", +] + + +ScoreConversationMode = Literal["last_message", "whole_conversation"] + + +class ScorerSummary(BaseModel): + """Minimal scorer entry used to populate the scoring dialog.""" + + scorer_registry_name: str = Field(..., description="Registry name of the scorer instance") + scorer_type: str = Field(..., description="Scorer class name (e.g., 'SelfAskRefusalScorer')") + score_type: str = Field(..., description="Score shape: 'true_false', 'float_scale', or 'unknown'") + description: str | None = Field( + None, + description=( + "First paragraph of the scorer class docstring. Surfaces in the GUI as an info pane so users " + "can see what each scorer does without leaving the dialog." + ), + ) + tags: list[str] = Field( + default_factory=list, + description="Registry tags (e.g. 'refusal', 'best_refusal'). Used in the GUI for grouping/badges.", + ) + + +class ScorerListResponse(BaseModel): + """Response listing every registered scorer.""" + + items: list[ScorerSummary] = Field(..., description="Registered scorers in registry-name order") + + +class ScoreConversationRequest(BaseModel): + """Request to score a conversation with a registered scorer.""" + + scorer_registry_name: str = Field(..., description="Registry name of the scorer to invoke") + mode: ScoreConversationMode = Field( + "last_message", + description=( + "'last_message' scores only the most recent assistant message; " + "'whole_conversation' wraps the scorer in a ConversationScorer and scores the full transcript." + ), + ) + objective: str | None = Field( + None, description="Optional objective to pass to the scorer (only used by objective scorers)" + ) + + +class ScoreMessageRequest(BaseModel): + """Request to score a single message piece with a registered scorer.""" + + scorer_registry_name: str = Field(..., description="Registry name of the scorer to invoke") + objective: str | None = Field( + None, description="Optional objective to pass to the scorer (only used by objective scorers)" + ) + + +class ScoreResponse(BaseModel): + """Response containing the scores produced by an on-demand scoring call.""" + + scores: list[Score] = Field(default_factory=list, description="Scores produced by the scorer") diff --git a/pyrit/backend/routes/scoring.py b/pyrit/backend/routes/scoring.py new file mode 100644 index 0000000000..85887c4058 --- /dev/null +++ b/pyrit/backend/routes/scoring.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +On-demand scoring routes. + +Surfaces two related endpoints: + +* ``GET /scorers`` — minimal list of registered scorer instances for the GUI dropdown. +* ``POST /attacks/{attack_result_id}/conversations/{conversation_id}/scores`` — score + either the last assistant message in a conversation or the whole conversation + (the latter wraps the chosen scorer in a ``ConversationScorer``). +* ``POST /attacks/{attack_result_id}/conversations/{conversation_id}/pieces/{piece_id}/scores`` + — score a single message piece. + +All scoring is delegated to ``ScoringService``, which itself calls ``Scorer.score_async`` +so the resulting scores are persisted in PyRIT memory and surfaced automatically by +``GET /attacks/{id}/messages`` on the next refresh. +""" + +import logging + +from fastapi import APIRouter, HTTPException, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.scoring import ( + ScoreConversationRequest, + ScoreMessageRequest, + ScoreResponse, + ScorerListResponse, +) +from pyrit.backend.services.scoring_service import get_scoring_service + +logger = logging.getLogger(__name__) + +scorers_router = APIRouter(prefix="/scorers", tags=["scorers"]) +attack_scoring_router = APIRouter(prefix="/attacks", tags=["attacks"]) + + +@scorers_router.get( + "", + response_model=ScorerListResponse, +) +async def list_scorers() -> ScorerListResponse: # pyrit-async-suffix-exempt + """ + List every registered scorer instance. + + Returns: + ScorerListResponse: Registered scorers in registry-name order. + """ + service = get_scoring_service() + return await service.list_scorers_async() + + +@attack_scoring_router.post( + "/{attack_result_id}/conversations/{conversation_id}/scores", + response_model=ScoreResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid scoring request"}, + 404: {"model": ProblemDetail, "description": "Attack, conversation, or scorer not found"}, + }, +) +async def score_conversation( # pyrit-async-suffix-exempt + attack_result_id: str, + conversation_id: str, + request: ScoreConversationRequest, +) -> ScoreResponse: + """ + Score a conversation belonging to an attack with a registered scorer. + + Args: + attack_result_id (str): The AttackResult primary key. + conversation_id (str): The conversation to score (must belong to the attack). + request (ScoreConversationRequest): Scorer name, mode, and optional objective. + + Returns: + ScoreResponse: The scores produced by the scorer (also persisted to memory). + """ + service = get_scoring_service() + + try: + return await service.score_conversation_async( + attack_result_id=attack_result_id, + conversation_id=conversation_id, + request=request, + ) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception( + "Failed to score conversation '%s' on attack '%s'", conversation_id, attack_result_id + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error. Check server logs for details.", + ) from e + + +@attack_scoring_router.post( + "/{attack_result_id}/conversations/{conversation_id}/pieces/{piece_id}/scores", + response_model=ScoreResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid scoring request"}, + 404: {"model": ProblemDetail, "description": "Attack, conversation, piece, or scorer not found"}, + }, +) +async def score_message_piece( # pyrit-async-suffix-exempt + attack_result_id: str, + conversation_id: str, + piece_id: str, + request: ScoreMessageRequest, +) -> ScoreResponse: + """ + Score a single message piece with a registered scorer. + + Args: + attack_result_id (str): The AttackResult primary key. + conversation_id (str): The conversation containing the piece. + piece_id (str): The message-piece id to score. + request (ScoreMessageRequest): Scorer name and optional objective. + + Returns: + ScoreResponse: The scores produced by the scorer (also persisted to memory). + """ + service = get_scoring_service() + + try: + return await service.score_message_async( + attack_result_id=attack_result_id, + conversation_id=conversation_id, + piece_id=piece_id, + request=request, + ) + except LookupError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except Exception as e: + logger.exception( + "Failed to score piece '%s' on conversation '%s' (attack '%s')", + piece_id, + conversation_id, + attack_result_id, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error. Check server logs for details.", + ) from e diff --git a/pyrit/backend/services/scoring_service.py b/pyrit/backend/services/scoring_service.py new file mode 100644 index 0000000000..5348754b66 --- /dev/null +++ b/pyrit/backend/services/scoring_service.py @@ -0,0 +1,273 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scoring service for invoking registered scorers on demand. + +This service is the thin glue between the REST surface and ``Scorer.score_async``: + +* ``list_scorers_async`` enumerates ``ScorerRegistry`` so the GUI can populate a dropdown. +* ``score_conversation_async`` resolves a scorer by registry name and applies it to either + the last assistant message in a conversation or the whole concatenated transcript + (via ``create_conversation_scorer``). +* ``score_message_async`` scores a single message piece in a conversation. + +All scoring runs through ``Scorer.score_async`` which persists scores to memory, so a +subsequent ``GET /attacks/{id}/messages`` call will surface the new scores on the +``BackendMessagePiece.scores`` field with no additional work here. +""" + +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import TYPE_CHECKING + +from pyrit.backend.mappers import pyrit_scores_to_dto +from pyrit.backend.models.attacks import Score +from pyrit.backend.models.scoring import ( + ScoreConversationMode, + ScoreConversationRequest, + ScoreMessageRequest, + ScoreResponse, + ScorerListResponse, + ScorerSummary, +) +from pyrit.memory import CentralMemory +from pyrit.registry import ScorerRegistry + +if TYPE_CHECKING: + from pyrit.models import Message + from pyrit.score.scorer import Scorer + +logger = logging.getLogger(__name__) + + +def _extract_class_description(cls: type) -> str | None: + """ + Extract the first paragraph of a class docstring as a short human-readable description. + + Matches the convention used by ``ConverterService.list_converter_catalog_async`` so the + UI can render scorer and converter info consistently. + """ + raw_doc = (cls.__doc__ or "").strip() + if not raw_doc: + return None + first_paragraph = raw_doc.split("\n\n")[0] + cleaned = " ".join(line.strip() for line in first_paragraph.splitlines() if line.strip()) + return cleaned or None + + +class ScoringService: + """ + Service that surfaces registered scorers and runs them against stored conversations. + + Scoring writes to memory via ``Scorer.score_async``, so callers do not need to + persist the returned ``Score`` DTOs themselves. + """ + + def __init__(self) -> None: + """Initialize the scoring service.""" + self._memory = CentralMemory.get_memory_instance() + self._registry = ScorerRegistry.get_registry_singleton() + + async def list_scorers_async(self) -> ScorerListResponse: # pyrit-async-suffix-exempt + """ + Enumerate every registered scorer (registry name, class, score type, description, tags). + + Returns: + ScorerListResponse: Registered scorers in registry-name order. + """ + items = [ + ScorerSummary( + scorer_registry_name=entry.name, + scorer_type=entry.instance.__class__.__name__, + score_type=entry.instance.scorer_type, + description=_extract_class_description(entry.instance.__class__), + tags=sorted(entry.tags.keys()) if entry.tags else [], + ) + for entry in self._registry.get_all_instances() + ] + return ScorerListResponse(items=items) + + async def score_conversation_async( + self, + *, + attack_result_id: str, + conversation_id: str, + request: ScoreConversationRequest, + ) -> ScoreResponse: + """ + Score a conversation belonging to an attack with a registered scorer. + + Args: + attack_result_id (str): The AttackResult primary key (used to verify existence). + conversation_id (str): The conversation to score (must belong to the attack). + request (ScoreConversationRequest): Scorer name, mode, and optional objective. + + Returns: + ScoreResponse: The scores produced by the scorer (also persisted to memory). + + Raises: + LookupError: If the attack does not exist. + ValueError: If the conversation does not belong to the attack, the conversation + has no scoreable assistant message, or the scorer registry name is unknown. + """ + self._verify_conversation_belongs_to_attack( + attack_result_id=attack_result_id, conversation_id=conversation_id + ) + + scorer = self._resolve_scorer(request.scorer_registry_name) + conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) + + if not conversation: + raise ValueError(f"Conversation '{conversation_id}' has no messages to score") + + target_message = self._select_message_for_scoring(conversation=conversation, mode=request.mode) + effective_scorer = self._maybe_wrap_for_conversation_scoring(scorer=scorer, mode=request.mode) + + scores = await effective_scorer.score_async(message=target_message, objective=request.objective) + return ScoreResponse(scores=pyrit_scores_to_dto(list(scores))) + + async def score_message_async( + self, + *, + attack_result_id: str, + conversation_id: str, + piece_id: str, + request: ScoreMessageRequest, + ) -> ScoreResponse: + """ + Score a single message piece in a conversation with a registered scorer. + + Args: + attack_result_id (str): The AttackResult primary key (used to verify existence). + conversation_id (str): The conversation containing the piece. + piece_id (str): The message-piece id to score. + request (ScoreMessageRequest): Scorer name and optional objective. + + Returns: + ScoreResponse: The scores produced by the scorer (also persisted to memory). + + Raises: + LookupError: If the attack does not exist, or the piece is not in the conversation. + ValueError: If the conversation does not belong to the attack or the scorer is unknown. + """ + self._verify_conversation_belongs_to_attack( + attack_result_id=attack_result_id, conversation_id=conversation_id + ) + + scorer = self._resolve_scorer(request.scorer_registry_name) + conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) + + target_message = self._find_message_containing_piece(conversation=conversation, piece_id=piece_id) + if target_message is None: + raise LookupError( + f"Message piece '{piece_id}' is not part of conversation '{conversation_id}'" + ) + + scores = await scorer.score_async(message=target_message, objective=request.objective) + return ScoreResponse(scores=pyrit_scores_to_dto(list(scores))) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _verify_conversation_belongs_to_attack( + self, *, attack_result_id: str, conversation_id: str + ) -> None: + """ + Raise ``LookupError`` if the attack does not exist, ``ValueError`` if the + conversation does not belong to it. + """ + results = self._memory.get_attack_results(attack_result_ids=[attack_result_id]) + if not results: + raise LookupError(f"Attack '{attack_result_id}' not found") + if conversation_id not in results[0].get_active_conversation_ids(): + raise ValueError( + f"Conversation '{conversation_id}' is not part of attack '{attack_result_id}'" + ) + + def _resolve_scorer(self, scorer_registry_name: str) -> Scorer: + """Resolve a scorer by registry name; raise ``ValueError`` when missing.""" + scorer = self._registry.get(scorer_registry_name) + if scorer is None: + raise ValueError(f"Scorer '{scorer_registry_name}' is not registered") + return scorer + + @staticmethod + def _select_message_for_scoring( + *, conversation: list[Message], mode: ScoreConversationMode + ) -> Message: + """ + Pick the message to hand to ``Scorer.score_async``. + + For ``last_message`` we score only the most recent assistant turn so the result + is comparable to a per-message score. For ``whole_conversation`` we just pick the + last message in the conversation — the ``ConversationScorer`` wrapper uses its + ``conversation_id`` to fetch the full transcript from memory. + """ + if mode == "whole_conversation": + return conversation[-1] + + # last_message: find the most recent assistant (or simulated assistant) turn. + for message in reversed(conversation): + if message.message_pieces and message.message_pieces[0].role in ( + "assistant", + "simulated_assistant", + ): + return message + raise ValueError("Conversation has no assistant message to score") + + @staticmethod + def _maybe_wrap_for_conversation_scoring( + *, scorer: Scorer, mode: ScoreConversationMode + ) -> Scorer: + """ + Wrap the scorer in a ``ConversationScorer`` when the caller asked for + whole-conversation scoring. Raises ``ValueError`` if the scorer cannot be wrapped + (i.e. it isn't a ``FloatScaleScorer`` or ``TrueFalseScorer``). + """ + if mode != "whole_conversation": + return scorer + + from pyrit.score.conversation_scorer import create_conversation_scorer + from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer + from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + + if not isinstance(scorer, (FloatScaleScorer, TrueFalseScorer)): + raise ValueError( + "Whole-conversation scoring requires a FloatScaleScorer or TrueFalseScorer; " + f"got {type(scorer).__name__}" + ) + return create_conversation_scorer(scorer=scorer) + + @staticmethod + def _find_message_containing_piece( + *, conversation: list[Message], piece_id: str + ) -> Message | None: + """Return the message in ``conversation`` whose pieces include ``piece_id``.""" + for message in conversation: + for piece in message.message_pieces: + if str(piece.id) == piece_id: + return message + return None + + +# ============================================================================ +# Singleton +# ============================================================================ + + +@lru_cache(maxsize=1) +def get_scoring_service() -> ScoringService: + """ + Get the global scoring service instance. + + Returns: + ScoringService: The singleton ``ScoringService`` instance. + """ + return ScoringService() + + +__all__ = ["ScoringService", "get_scoring_service", "Score"] diff --git a/tests/unit/backend/test_scoring_service.py b/tests/unit/backend/test_scoring_service.py new file mode 100644 index 0000000000..fb3a0a3aee --- /dev/null +++ b/tests/unit/backend/test_scoring_service.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the scoring service. + +Mocks ``ScorerRegistry``, ``CentralMemory``, and the per-scorer ``score_async`` to +exercise the orchestration logic in isolation. +""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.backend.models.scoring import ( + ScoreConversationRequest, + ScoreMessageRequest, +) +from pyrit.backend.services.scoring_service import ( + ScoringService, + get_scoring_service, +) +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, build_atomic_attack_identifier +from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + + +@pytest.fixture +def mock_memory(): + memory = MagicMock() + memory.get_attack_results.return_value = [] + memory.get_conversation.return_value = [] + return memory + + +@pytest.fixture +def mock_registry(): + registry = MagicMock() + registry.get.return_value = None + registry.get_all_instances.return_value = [] + return registry + + +@pytest.fixture +def scoring_service(mock_memory, mock_registry): + with patch("pyrit.backend.services.scoring_service.CentralMemory") as mock_central, patch( + "pyrit.backend.services.scoring_service.ScorerRegistry" + ) as mock_registry_cls: + mock_central.get_memory_instance.return_value = mock_memory + mock_registry_cls.get_registry_singleton.return_value = mock_registry + # Bypass lru_cache so each test gets a fresh service instance bound to the mocks above. + get_scoring_service.cache_clear() + service = ScoringService() + yield service + get_scoring_service.cache_clear() + + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # + + +def _make_attack_result(*, conversation_id: str = "conv-1", attack_result_id: str = "ar-1") -> AttackResult: + target_identifier = ComponentIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target", + ) + now = datetime.now(timezone.utc) + return AttackResult( + conversation_id=conversation_id, + objective="Test", + atomic_attack_identifier=build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name="ManualAttack", + class_module="pyrit.backend", + children={"objective_target": target_identifier}, + ), + ), + outcome=AttackOutcome.UNDETERMINED, + attack_result_id=attack_result_id, + metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, + labels={}, + ) + + +def _make_piece(*, role: str = "assistant", piece_id: str | None = None) -> MagicMock: + piece = MagicMock() + piece.id = piece_id or uuid.uuid4() + piece.role = role + piece.api_role = "assistant" if role in ("assistant", "simulated_assistant") else role + piece.scores = [] + return piece + + +def _make_message(pieces: list[MagicMock]) -> MagicMock: + msg = MagicMock() + msg.message_pieces = pieces + return msg + + +def _make_pyrit_score(*, value: str = "true", category: str = "harm") -> MagicMock: + score = MagicMock() + score.id = uuid.uuid4() + score.scorer_class_identifier = ComponentIdentifier( + class_name="FakeScorer", + class_module="tests", + ) + score.score_type = "true_false" + score.score_value = value + score.score_category = [category] + score.score_rationale = "because" + score.timestamp = datetime.now(timezone.utc) + return score + + +# --------------------------------------------------------------------------- # +# list_scorers_async +# --------------------------------------------------------------------------- # + + +class TestListScorers: + async def test_returns_empty_when_no_scorers(self, scoring_service, mock_registry) -> None: + mock_registry.get_all_instances.return_value = [] + + result = await scoring_service.list_scorers_async() + + assert result.items == [] + + async def test_returns_registered_scorers(self, scoring_service, mock_registry) -> None: + scorer = MagicMock(spec=TrueFalseScorer) + scorer.scorer_type = "true_false" + entry = MagicMock() + entry.name = "my-scorer" + entry.instance = scorer + entry.tags = {"refusal": "", "best_refusal": ""} + mock_registry.get_all_instances.return_value = [entry] + + result = await scoring_service.list_scorers_async() + + assert len(result.items) == 1 + item = result.items[0] + assert item.scorer_registry_name == "my-scorer" + assert item.score_type == "true_false" + assert sorted(item.tags) == ["best_refusal", "refusal"] + # MagicMock(spec=TrueFalseScorer) inherits TrueFalseScorer.__doc__, + # so description should come from the real class docstring (first paragraph). + assert item.description and len(item.description) > 0 + + async def test_description_falls_back_to_none_when_class_has_no_docstring( + self, scoring_service, mock_registry + ) -> None: + class _Undocumented: + pass + + scorer = MagicMock() + scorer.scorer_type = "true_false" + scorer.__class__ = _Undocumented + entry = MagicMock() + entry.name = "undoc" + entry.instance = scorer + entry.tags = {} + mock_registry.get_all_instances.return_value = [entry] + + result = await scoring_service.list_scorers_async() + assert result.items[0].description is None + assert result.items[0].tags == [] + + +# --------------------------------------------------------------------------- # +# score_conversation_async +# --------------------------------------------------------------------------- # + + +class TestScoreConversation: + async def test_raises_when_attack_missing(self, scoring_service, mock_memory) -> None: + mock_memory.get_attack_results.return_value = [] + + with pytest.raises(LookupError, match="not found"): + await scoring_service.score_conversation_async( + attack_result_id="missing", + conversation_id="conv-1", + request=ScoreConversationRequest(scorer_registry_name="x"), + ) + + async def test_raises_when_conversation_not_in_attack(self, scoring_service, mock_memory) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result(conversation_id="conv-1")] + + with pytest.raises(ValueError, match="not part of attack"): + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="other-conv", + request=ScoreConversationRequest(scorer_registry_name="x"), + ) + + async def test_raises_when_scorer_missing( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_registry.get.return_value = None + + with pytest.raises(ValueError, match="not registered"): + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest(scorer_registry_name="missing-scorer"), + ) + + async def test_raises_when_conversation_empty( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [] + mock_registry.get.return_value = MagicMock(spec=TrueFalseScorer) + + with pytest.raises(ValueError, match="no messages to score"): + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest(scorer_registry_name="x"), + ) + + async def test_raises_when_last_message_has_no_assistant_turn( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [_make_message([_make_piece(role="user")])] + mock_registry.get.return_value = MagicMock(spec=TrueFalseScorer) + + with pytest.raises(ValueError, match="no assistant message"): + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest(scorer_registry_name="x"), + ) + + async def test_last_message_scores_most_recent_assistant_turn( + self, scoring_service, mock_memory, mock_registry + ) -> None: + user_msg = _make_message([_make_piece(role="user")]) + first_assistant = _make_message([_make_piece(role="assistant")]) + last_assistant = _make_message([_make_piece(role="assistant")]) + trailing_user = _make_message([_make_piece(role="user")]) + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [user_msg, first_assistant, user_msg, last_assistant, trailing_user] + + scorer = MagicMock(spec=TrueFalseScorer) + scorer.score_async = AsyncMock(return_value=[_make_pyrit_score()]) + mock_registry.get.return_value = scorer + + result = await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest(scorer_registry_name="my-scorer", objective="be helpful"), + ) + + scorer.score_async.assert_awaited_once() + kwargs = scorer.score_async.await_args.kwargs + assert kwargs["message"] is last_assistant + assert kwargs["objective"] == "be helpful" + assert len(result.scores) == 1 + assert result.scores[0].score_value == "true" + + async def test_whole_conversation_wraps_scorer( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + # Whole-conv mode just hands the last message to the wrapped scorer; content doesn't matter. + last = _make_message([_make_piece(role="assistant")]) + mock_memory.get_conversation.return_value = [last] + + scorer = MagicMock(spec=FloatScaleScorer) + mock_registry.get.return_value = scorer + + with patch( + "pyrit.score.conversation_scorer.create_conversation_scorer" + ) as mock_create: + wrapped = MagicMock() + wrapped.score_async = AsyncMock(return_value=[_make_pyrit_score()]) + mock_create.return_value = wrapped + + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest( + scorer_registry_name="my-scorer", mode="whole_conversation" + ), + ) + + mock_create.assert_called_once_with(scorer=scorer) + wrapped.score_async.assert_awaited_once() + + async def test_whole_conversation_rejects_unsupported_scorer( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [_make_message([_make_piece(role="assistant")])] + mock_registry.get.return_value = MagicMock() # Not a FloatScale/TrueFalse scorer. + + with pytest.raises(ValueError, match="FloatScaleScorer or TrueFalseScorer"): + await scoring_service.score_conversation_async( + attack_result_id="ar-1", + conversation_id="conv-1", + request=ScoreConversationRequest( + scorer_registry_name="my-scorer", mode="whole_conversation" + ), + ) + + +# --------------------------------------------------------------------------- # +# score_message_async +# --------------------------------------------------------------------------- # + + +class TestScoreMessage: + async def test_scores_specific_piece(self, scoring_service, mock_memory, mock_registry) -> None: + target_piece = _make_piece(role="assistant", piece_id="piece-target") + other_piece = _make_piece(role="assistant", piece_id="piece-other") + target_msg = _make_message([target_piece]) + other_msg = _make_message([other_piece]) + + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [other_msg, target_msg] + + scorer = MagicMock(spec=TrueFalseScorer) + scorer.score_async = AsyncMock(return_value=[_make_pyrit_score()]) + mock_registry.get.return_value = scorer + + result = await scoring_service.score_message_async( + attack_result_id="ar-1", + conversation_id="conv-1", + piece_id="piece-target", + request=ScoreMessageRequest(scorer_registry_name="my-scorer"), + ) + + scorer.score_async.assert_awaited_once() + assert scorer.score_async.await_args.kwargs["message"] is target_msg + assert len(result.scores) == 1 + + async def test_raises_when_piece_not_in_conversation( + self, scoring_service, mock_memory, mock_registry + ) -> None: + mock_memory.get_attack_results.return_value = [_make_attack_result()] + mock_memory.get_conversation.return_value = [ + _make_message([_make_piece(role="assistant", piece_id="other")]) + ] + mock_registry.get.return_value = MagicMock(spec=TrueFalseScorer) + + with pytest.raises(LookupError, match="not part of conversation"): + await scoring_service.score_message_async( + attack_result_id="ar-1", + conversation_id="conv-1", + piece_id="missing-piece", + request=ScoreMessageRequest(scorer_registry_name="x"), + ) From 597b8d898840f3a94084856b16649683cdfd9c89 Mon Sep 17 00:00:00 2001 From: jbolor21 <86250273+jbolor21@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:48:26 -0700 Subject: [PATCH 2/3] cleaning up UI, added uses_objective flag to scorers --- frontend/src/components/Chat/ChatWindow.tsx | 28 ++- .../src/components/Chat/ScoreDialog.test.tsx | 204 +++++++++++++++++- frontend/src/components/Chat/ScoreDialog.tsx | 119 +++++++--- frontend/src/types/index.ts | 1 + pyrit/backend/models/scoring.py | 9 + pyrit/backend/services/scoring_service.py | 33 +-- pyrit/score/conversation_scorer.py | 5 + .../float_scale/audio_float_scale_scorer.py | 5 + .../self_ask_general_float_scale_scorer.py | 1 + .../float_scale/self_ask_scale_scorer.py | 1 + .../float_scale/video_float_scale_scorer.py | 13 ++ pyrit/score/scorer.py | 11 + .../true_false/audio_true_false_scorer.py | 5 + .../float_scale_threshold_scorer.py | 5 + .../self_ask_general_true_false_scorer.py | 1 + .../self_ask_question_answer_scorer.py | 1 + .../true_false/self_ask_refusal_scorer.py | 1 + .../true_false/self_ask_true_false_scorer.py | 1 + .../true_false/true_false_composite_scorer.py | 6 +- .../true_false/true_false_inverter_scorer.py | 5 + .../true_false/video_true_false_scorer.py | 13 ++ tests/unit/backend/test_scoring_service.py | 63 +++--- 22 files changed, 445 insertions(+), 86 deletions(-) diff --git a/frontend/src/components/Chat/ChatWindow.tsx b/frontend/src/components/Chat/ChatWindow.tsx index ad12765f15..c8205295e7 100644 --- a/frontend/src/components/Chat/ChatWindow.tsx +++ b/frontend/src/components/Chat/ChatWindow.tsx @@ -76,6 +76,14 @@ export default function ChatWindow({ const [pieceConversions, setPieceConversions] = useState>({}) const [panelRefreshKey, setPanelRefreshKey] = useState(0) const [scoreTarget, setScoreTarget] = useState(null) + // Last-used scorer per conversation id. Lets the score dialog pre-select the + // scorer the user previously picked for the same conversation. Persists for + // the lifetime of the ChatWindow (not across page reloads); the user can + // still pick a different scorer at any time. + const [scorerByConversation, setScorerByConversation] = useState>({}) + // Last-typed objective per conversation id. Mirrors scorerByConversation so + // re-opening the dialog pre-fills the objective the user previously typed. + const [objectiveByConversation, setObjectiveByConversation] = useState>({}) const inputBoxRef = useRef(null) const handleAttachmentsChange = useCallback((types: string[], data: Record) => { @@ -611,7 +619,7 @@ export default function ChatWindow({ data-testid="score-conversation-btn" aria-label="Score conversation" > - Score + Score conversation @@ -722,6 +730,24 @@ export default function ChatWindow({ target={scoreTarget} onClose={() => setScoreTarget(null)} onScored={() => { setScoreTarget(null); handleScored() }} + initialScorerName={scoreTarget ? scorerByConversation[scoreTarget.conversationId] : undefined} + onScorerSelected={(name) => { + if (!scoreTarget) return + setScorerByConversation((prev) => + prev[scoreTarget.conversationId] === name + ? prev + : { ...prev, [scoreTarget.conversationId]: name } + ) + }} + initialObjective={scoreTarget ? objectiveByConversation[scoreTarget.conversationId] : undefined} + onObjectiveChange={(value) => { + if (!scoreTarget) return + setObjectiveByConversation((prev) => + prev[scoreTarget.conversationId] === value + ? prev + : { ...prev, [scoreTarget.conversationId]: value } + ) + }} /> ) diff --git a/frontend/src/components/Chat/ScoreDialog.test.tsx b/frontend/src/components/Chat/ScoreDialog.test.tsx index 7897c275fc..271560fcba 100644 --- a/frontend/src/components/Chat/ScoreDialog.test.tsx +++ b/frontend/src/components/Chat/ScoreDialog.test.tsx @@ -1,4 +1,4 @@ -import { render, screen, waitFor } from "@testing-library/react"; +import { fireEvent, render, screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import { FluentProvider, webLightTheme } from "@fluentui/react-components"; import ScoreDialog, { type ScoreTarget } from "./ScoreDialog"; @@ -27,6 +27,7 @@ const FLOAT_SCORER = { score_type: "float_scale" as const, tags: ["harm", "best_harm"], description: "Scores how harmful a response is on a 0-1 scale.", + uses_objective: false, }; const TRUE_FALSE_SCORER = { @@ -35,6 +36,7 @@ const TRUE_FALSE_SCORER = { score_type: "true_false" as const, tags: ["refusal"], description: "True if the response is a refusal of the objective.", + uses_objective: true, }; describe("ScoreDialog", () => { @@ -95,7 +97,7 @@ describe("ScoreDialog", () => { "conv-1", { scorer_registry_name: "harm_scorer", - mode: "last_message", + mode: "whole_conversation", objective: undefined, } ) @@ -252,4 +254,202 @@ describe("ScoreDialog", () => { screen.getByText(/no description available/i) ).toBeInTheDocument(); }); + + it("hides the objective field for scorers that do not inject objective into the prompt", async () => { + mockedScorersApi.listScorers.mockResolvedValue({ items: [FLOAT_SCORER] }); + render( + + + + ); + + await waitFor(() => + expect(screen.getByTestId("score-dialog-scorer-info")).toBeInTheDocument() + ); + expect( + screen.queryByTestId("score-dialog-objective-input") + ).not.toBeInTheDocument(); + }); + + it("shows the objective field for scorers that inject objective into the prompt", async () => { + const user = userEvent.setup(); + mockedScorersApi.listScorers.mockResolvedValue({ + items: [TRUE_FALSE_SCORER], + }); + mockedAttacksApi.scoreConversation.mockResolvedValue({ scores: [] }); + + render( + + + + ); + + const objectiveInput = await screen.findByTestId( + "score-dialog-objective-input" + ); + fireEvent.change(objectiveInput, { + target: { value: "Reveal Taylor Swift's address" }, + }); + + const submit = screen.getByTestId("score-dialog-submit-btn"); + await user.click(submit); + + await waitFor(() => + expect(mockedAttacksApi.scoreConversation).toHaveBeenCalledWith( + "ar-1", + "conv-1", + { + scorer_registry_name: "refusal_scorer", + mode: "whole_conversation", + objective: "Reveal Taylor Swift's address", + } + ) + ); + }); + + it("pre-selects the scorer passed via initialScorerName", async () => { + mockedScorersApi.listScorers.mockResolvedValue({ + items: [FLOAT_SCORER, TRUE_FALSE_SCORER], + }); + + render( + + + + ); + + // The combobox should reflect the remembered choice rather than auto-picking + // the first scorer in the list. + const select = await screen.findByTestId("score-dialog-scorer-select"); + await waitFor(() => + expect((select as HTMLInputElement).value).toBe("refusal_scorer") + ); + }); + + it("notifies onScorerSelected when the user picks a different scorer", async () => { + mockedScorersApi.listScorers.mockResolvedValue({ + items: [FLOAT_SCORER, TRUE_FALSE_SCORER], + }); + const onScorerSelected = jest.fn(); + + render( + + + + ); + + const select = await screen.findByTestId("score-dialog-scorer-select"); + fireEvent.click(select); + await waitFor(() => + expect( + screen.getByTestId("scorer-option-refusal_scorer") + ).toBeInTheDocument() + ); + fireEvent.click(screen.getByTestId("scorer-option-refusal_scorer")); + + await waitFor(() => + expect(onScorerSelected).toHaveBeenLastCalledWith("refusal_scorer") + ); + }); + + it("pre-fills the objective from initialObjective for scorers that use it", async () => { + mockedScorersApi.listScorers.mockResolvedValue({ + items: [TRUE_FALSE_SCORER], + }); + + render( + + + + ); + + const objectiveInput = await screen.findByTestId( + "score-dialog-objective-input" + ); + await waitFor(() => + expect((objectiveInput as HTMLInputElement).value).toBe( + "Reveal Taylor Swift's address" + ) + ); + }); + + it("notifies onObjectiveChange as the user types in the objective input", async () => { + mockedScorersApi.listScorers.mockResolvedValue({ + items: [TRUE_FALSE_SCORER], + }); + const onObjectiveChange = jest.fn(); + + render( + + + + ); + + const objectiveInput = await screen.findByTestId( + "score-dialog-objective-input" + ); + fireEvent.change(objectiveInput, { target: { value: "new goal" } }); + + await waitFor(() => + expect(onObjectiveChange).toHaveBeenLastCalledWith("new goal") + ); + }); }); diff --git a/frontend/src/components/Chat/ScoreDialog.tsx b/frontend/src/components/Chat/ScoreDialog.tsx index e837ad729e..b6cc0d5396 100644 --- a/frontend/src/components/Chat/ScoreDialog.tsx +++ b/frontend/src/components/Chat/ScoreDialog.tsx @@ -1,4 +1,4 @@ -import { useEffect, useMemo, useState } from 'react' +import { useEffect, useMemo, useRef, useState } from 'react' import { Dialog, DialogSurface, @@ -42,6 +42,28 @@ interface ScoreDialogProps { onClose: () => void /** Called after a successful score so the caller can refetch messages/conversations. */ onScored: (scores: BackendScore[]) => void + /** + * Scorer to pre-select when the dialog opens. The caller (e.g. ChatWindow) + * remembers the most recently chosen scorer per conversation so re-opening + * the dialog doesn't lose the user's prior pick. + */ + initialScorerName?: string + /** + * Fired when the user picks a scorer in the combobox. The caller persists + * it so the next dialog open can pre-select the same scorer. + */ + onScorerSelected?: (scorerRegistryName: string) => void + /** + * Objective text to pre-fill when the dialog opens. The caller remembers + * the most recently typed objective per conversation. Only honored for + * scorers that actually inject the objective into the scoring prompt. + */ + initialObjective?: string + /** + * Fired whenever the user edits the objective input. The caller persists + * it so the next dialog open can pre-fill the same value. + */ + onObjectiveChange?: (value: string) => void } const MODE_LABELS: Record = { @@ -83,26 +105,46 @@ function groupScorers(scorers: ScorerSummary[]): { score_type: string; items: Sc return ordered } -export default function ScoreDialog({ open, target, onClose, onScored }: ScoreDialogProps) { +export default function ScoreDialog({ + open, + target, + onClose, + onScored, + initialScorerName, + onScorerSelected, + initialObjective, + onObjectiveChange, +}: ScoreDialogProps) { const [scorers, setScorers] = useState([]) const [loadingScorers, setLoadingScorers] = useState(false) const [loadError, setLoadError] = useState(null) const [selectedScorerName, setSelectedScorerName] = useState('') const [scorerQuery, setScorerQuery] = useState('') - const [mode, setMode] = useState('last_message') + const [mode, setMode] = useState('whole_conversation') const [objective, setObjective] = useState('') const [submitting, setSubmitting] = useState(false) const [submitError, setSubmitError] = useState(null) const isConversationScope = target?.kind === 'conversation' + // Snapshot of initialScorerName read at open time only. Using a ref means + // the reset effect below doesn't re-fire (and wipe user edits) every time + // the parent updates the cached scorer name as the user picks options. + const initialScorerNameRef = useRef(initialScorerName) + const initialObjectiveRef = useRef(initialObjective) + useEffect(() => { + initialScorerNameRef.current = initialScorerName + initialObjectiveRef.current = initialObjective + }) + // Reset form whenever the dialog re-opens against a new target. useEffect(() => { if (!open) return - setSelectedScorerName('') - setScorerQuery('') - setMode('last_message') - setObjective('') + const seed = initialScorerNameRef.current ?? '' + setSelectedScorerName(seed) + setScorerQuery(seed) + setMode('whole_conversation') + setObjective(initialObjectiveRef.current ?? '') setSubmitError(null) }, [open, target]) @@ -161,12 +203,20 @@ export default function ScoreDialog({ open, target, onClose, onScored }: ScoreDi } }, [wholeConversationDisabled, mode]) + // Most scorers don't actually inject the objective into the scoring prompt + // (it's only attached to the resulting Score row as metadata). We hide the + // input for those scorers but keep the typed value in state so that toggling + // back to an injecting scorer (or persisting via onObjectiveChange) doesn't + // wipe what the user typed. Submission also gates on this flag so a hidden + // stale value can never reach the backend. + const scorerUsesObjective = selectedScorer?.uses_objective === true + const handleSubmit = async () => { if (!target || !selectedScorerName) return setSubmitting(true) setSubmitError(null) try { - const trimmedObjective = objective.trim() || undefined + const trimmedObjective = scorerUsesObjective ? objective.trim() || undefined : undefined if (target.kind === 'conversation') { const response = await attacksApi.scoreConversation( target.attackResultId, @@ -262,8 +312,10 @@ export default function ScoreDialog({ open, target, onClose, onScored }: ScoreDi value={scorerQuery} selectedOptions={selectedScorerName ? [selectedScorerName] : []} onOptionSelect={(_, data) => { - setSelectedScorerName(data.optionValue ?? '') + const name = data.optionValue ?? '' + setSelectedScorerName(name) setScorerQuery(data.optionText ?? '') + if (name) onScorerSelected?.(name) }} onChange={(e) => setScorerQuery((e.target as HTMLInputElement).value)} placeholder="Search scorers..." @@ -349,28 +401,33 @@ export default function ScoreDialog({ open, target, onClose, onScored }: ScoreDi )} - - Objective (optional) - - - - - - - } - > - setObjective(data.value)} - placeholder="e.g. The model agreed to draft a phishing email" - data-testid="score-dialog-objective-input" - /> - + {scorerUsesObjective && ( + + Objective (optional) + + + + + + + } + > + { + setObjective(data.value) + onObjectiveChange?.(data.value) + }} + placeholder="e.g. The model agreed to draft a phishing email" + data-testid="score-dialog-objective-input" + /> + + )} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 743cdcc8cd..6f538bc0ee 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -292,6 +292,7 @@ export interface ScorerSummary { score_type: ScorerScoreType description?: string | null tags?: string[] + uses_objective?: boolean } export interface ScorerListResponse { diff --git a/pyrit/backend/models/scoring.py b/pyrit/backend/models/scoring.py index 55f7cf0514..f8ea92891f 100644 --- a/pyrit/backend/models/scoring.py +++ b/pyrit/backend/models/scoring.py @@ -46,6 +46,15 @@ class ScorerSummary(BaseModel): default_factory=list, description="Registry tags (e.g. 'refusal', 'best_refusal'). Used in the GUI for grouping/badges.", ) + uses_objective: bool = Field( + False, + description=( + "True if this scorer injects the caller-supplied objective into its scoring prompt so the " + "judge LLM is conditioned on it. When False, the objective is only stored on the resulting " + "Score row as metadata and has no effect on the scorer's verdict. Read off " + "``Scorer.uses_objective``. The GUI hides the objective input for scorers where this is False." + ), + ) class ScorerListResponse(BaseModel): diff --git a/pyrit/backend/services/scoring_service.py b/pyrit/backend/services/scoring_service.py index 5348754b66..75500b02a9 100644 --- a/pyrit/backend/services/scoring_service.py +++ b/pyrit/backend/services/scoring_service.py @@ -85,6 +85,7 @@ async def list_scorers_async(self) -> ScorerListResponse: # pyrit-async-suffix- score_type=entry.instance.scorer_type, description=_extract_class_description(entry.instance.__class__), tags=sorted(entry.tags.keys()) if entry.tags else [], + uses_objective=bool(entry.instance.uses_objective), ) for entry in self._registry.get_all_instances() ] @@ -113,9 +114,7 @@ async def score_conversation_async( ValueError: If the conversation does not belong to the attack, the conversation has no scoreable assistant message, or the scorer registry name is unknown. """ - self._verify_conversation_belongs_to_attack( - attack_result_id=attack_result_id, conversation_id=conversation_id - ) + self._verify_conversation_belongs_to_attack(attack_result_id=attack_result_id, conversation_id=conversation_id) scorer = self._resolve_scorer(request.scorer_registry_name) conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) @@ -153,18 +152,14 @@ async def score_message_async( LookupError: If the attack does not exist, or the piece is not in the conversation. ValueError: If the conversation does not belong to the attack or the scorer is unknown. """ - self._verify_conversation_belongs_to_attack( - attack_result_id=attack_result_id, conversation_id=conversation_id - ) + self._verify_conversation_belongs_to_attack(attack_result_id=attack_result_id, conversation_id=conversation_id) scorer = self._resolve_scorer(request.scorer_registry_name) conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) target_message = self._find_message_containing_piece(conversation=conversation, piece_id=piece_id) if target_message is None: - raise LookupError( - f"Message piece '{piece_id}' is not part of conversation '{conversation_id}'" - ) + raise LookupError(f"Message piece '{piece_id}' is not part of conversation '{conversation_id}'") scores = await scorer.score_async(message=target_message, objective=request.objective) return ScoreResponse(scores=pyrit_scores_to_dto(list(scores))) @@ -173,9 +168,7 @@ async def score_message_async( # Helpers # ------------------------------------------------------------------ - def _verify_conversation_belongs_to_attack( - self, *, attack_result_id: str, conversation_id: str - ) -> None: + def _verify_conversation_belongs_to_attack(self, *, attack_result_id: str, conversation_id: str) -> None: """ Raise ``LookupError`` if the attack does not exist, ``ValueError`` if the conversation does not belong to it. @@ -184,9 +177,7 @@ def _verify_conversation_belongs_to_attack( if not results: raise LookupError(f"Attack '{attack_result_id}' not found") if conversation_id not in results[0].get_active_conversation_ids(): - raise ValueError( - f"Conversation '{conversation_id}' is not part of attack '{attack_result_id}'" - ) + raise ValueError(f"Conversation '{conversation_id}' is not part of attack '{attack_result_id}'") def _resolve_scorer(self, scorer_registry_name: str) -> Scorer: """Resolve a scorer by registry name; raise ``ValueError`` when missing.""" @@ -196,9 +187,7 @@ def _resolve_scorer(self, scorer_registry_name: str) -> Scorer: return scorer @staticmethod - def _select_message_for_scoring( - *, conversation: list[Message], mode: ScoreConversationMode - ) -> Message: + def _select_message_for_scoring(*, conversation: list[Message], mode: ScoreConversationMode) -> Message: """ Pick the message to hand to ``Scorer.score_async``. @@ -220,9 +209,7 @@ def _select_message_for_scoring( raise ValueError("Conversation has no assistant message to score") @staticmethod - def _maybe_wrap_for_conversation_scoring( - *, scorer: Scorer, mode: ScoreConversationMode - ) -> Scorer: + def _maybe_wrap_for_conversation_scoring(*, scorer: Scorer, mode: ScoreConversationMode) -> Scorer: """ Wrap the scorer in a ``ConversationScorer`` when the caller asked for whole-conversation scoring. Raises ``ValueError`` if the scorer cannot be wrapped @@ -243,9 +230,7 @@ def _maybe_wrap_for_conversation_scoring( return create_conversation_scorer(scorer=scorer) @staticmethod - def _find_message_containing_piece( - *, conversation: list[Message], piece_id: str - ) -> Message | None: + def _find_message_containing_piece(*, conversation: list[Message], piece_id: str) -> Message | None: """Return the message in ``conversation`` whose pieces include ``piece_id``.""" for message in conversation: for piece in message.message_pieces: diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d921b2e1cf..9246178913 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -155,6 +155,11 @@ def validate_return_scores(self, scores: list[Score]) -> None: wrapped_scorer = self._get_wrapped_scorer() wrapped_scorer.validate_return_scores(scores) + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """Delegate to the wrapped scorer so the GUI's objective gating reflects the inner scorer.""" + return self._get_wrapped_scorer().uses_objective + def create_conversation_scorer( *, diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index 17653c9d5f..65fa7d67c7 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -73,3 +73,8 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st List of scores from evaluating the transcribed audio. """ return await self._audio_helper._score_audio_async(message_piece=message_piece, objective=objective) + + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """Delegate to the wrapped text scorer.""" + return self._audio_helper.text_scorer.uses_objective diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 17105defb9..368ef12fdf 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -25,6 +25,7 @@ class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): is_objective_required=True, ) TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 92db37a06a..6492714ef8 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -39,6 +39,7 @@ class SystemPaths(enum.Enum): is_objective_required=True, ) TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index 8e32bd9064..fc77640bd5 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -116,6 +116,19 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """True if either sub-scorer uses the objective AND its template enables objective flow-through.""" + image_uses = ( + self._video_helper.image_objective_template is not None and self._video_helper.image_scorer.uses_objective + ) + audio_uses = ( + self.audio_scorer is not None + and self._video_helper.audio_objective_template is not None + and self.audio_scorer.uses_objective + ) + return image_uses or audio_uses + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index f3cda9923b..3aa856ff70 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -83,6 +83,17 @@ class Scorer(Identifiable, abc.ABC): #: (Chat Completions API) and ``OpenAIResponseTarget`` (Responses API). score_blocked_content: bool = False + #: When True, this scorer injects the caller-supplied ``objective`` into the + #: scoring prompt (system or user message) so the judge LLM is conditioned + #: on it. When False, the ``objective`` is only attached to the resulting + #: ``Score`` row as metadata and does not influence the scorer's verdict. + #: + #: Surfaced in the GUI (``ScorerSummary.uses_objective``) so the + #: scoring dialog can hide the objective input for scorers that ignore it. + #: Wrapper scorers (composite, inverter, threshold, conversation, audio/video) + #: should override this with a property that delegates to the wrapped scorer. + uses_objective: bool = False + def __init_subclass__(cls, **kwargs: Any) -> None: """ Enforce the keyword-only constructor contract on subclasses. diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index 58397a3a29..341072ed63 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -73,3 +73,8 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st List of scores from evaluating the transcribed audio. """ return await self._audio_helper._score_audio_async(message_piece=message_piece, objective=objective) + + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """Delegate to the wrapped text scorer.""" + return self._audio_helper.text_scorer.uses_objective diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index 828b98a9dd..66a0ae0bcc 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -84,6 +84,11 @@ def get_chat_target(self) -> Optional["PromptTarget"]: """ return self._scorer.get_chat_target() + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """Delegate to the wrapped scorer.""" + return self._scorer.uses_objective + async def _score_async( self, message: Message, diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 71acd45a56..9163a8a2fd 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -29,6 +29,7 @@ class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): is_objective_required=False, ) TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index a2f5bc078e..082bc20680 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -33,6 +33,7 @@ class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): supported_data_types=["text"], is_objective_required=True, ) + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index b5a5c2b80c..9038f374ec 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -63,6 +63,7 @@ class SelfAskRefusalScorer(TrueFalseScorer): _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index bdb9fc21c3..ea19b7218d 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -105,6 +105,7 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): supported_data_types=["text", "image_path"], ) TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + uses_objective: bool = True def __init__( self, diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index 0fece73d64..0c7778b0fa 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -48,7 +48,6 @@ def __init__( if not scorers: raise ValueError("At least one scorer must be provided.") - for scorer in scorers: if not isinstance(scorer, TrueFalseScorer): raise ValueError("All scorers must be true_false scorers.") @@ -79,6 +78,11 @@ def get_chat_target(self) -> Optional["PromptTarget"]: return target return None + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """True if any child scorer injects the objective into its scoring prompt.""" + return any(s.uses_objective for s in self._scorers) + async def _score_async( self, message: Message, diff --git a/pyrit/score/true_false/true_false_inverter_scorer.py b/pyrit/score/true_false/true_false_inverter_scorer.py index c3b894edda..7013fce75a 100644 --- a/pyrit/score/true_false/true_false_inverter_scorer.py +++ b/pyrit/score/true_false/true_false_inverter_scorer.py @@ -58,6 +58,11 @@ def get_chat_target(self) -> Optional["PromptTarget"]: """ return self._scorer.get_chat_target() + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """Delegate to the wrapped scorer.""" + return self._scorer.uses_objective + async def _score_async( self, message: Message, diff --git a/pyrit/score/true_false/video_true_false_scorer.py b/pyrit/score/true_false/video_true_false_scorer.py index 5c45eae477..b8b43dffa0 100644 --- a/pyrit/score/true_false/video_true_false_scorer.py +++ b/pyrit/score/true_false/video_true_false_scorer.py @@ -93,6 +93,19 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) + @property + def uses_objective(self) -> bool: # type: ignore[ty:invalid-overload] + """True if either sub-scorer uses the objective AND its template enables objective flow-through.""" + image_uses = ( + self._video_helper.image_objective_template is not None and self._video_helper.image_scorer.uses_objective + ) + audio_uses = ( + self.audio_scorer is not None + and self._video_helper.audio_objective_template is not None + and self.audio_scorer.uses_objective + ) + return image_uses or audio_uses + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. diff --git a/tests/unit/backend/test_scoring_service.py b/tests/unit/backend/test_scoring_service.py index fb3a0a3aee..a00253be45 100644 --- a/tests/unit/backend/test_scoring_service.py +++ b/tests/unit/backend/test_scoring_service.py @@ -45,9 +45,10 @@ def mock_registry(): @pytest.fixture def scoring_service(mock_memory, mock_registry): - with patch("pyrit.backend.services.scoring_service.CentralMemory") as mock_central, patch( - "pyrit.backend.services.scoring_service.ScorerRegistry" - ) as mock_registry_cls: + with ( + patch("pyrit.backend.services.scoring_service.CentralMemory") as mock_central, + patch("pyrit.backend.services.scoring_service.ScorerRegistry") as mock_registry_cls, + ): mock_central.get_memory_instance.return_value = mock_memory mock_registry_cls.get_registry_singleton.return_value = mock_registry # Bypass lru_cache so each test gets a fresh service instance bound to the mocks above. @@ -167,6 +168,30 @@ class _Undocumented: assert result.items[0].description is None assert result.items[0].tags == [] + async def test_uses_objective_is_read_from_scorer_instance(self, scoring_service, mock_registry) -> None: + injecting = MagicMock(spec=TrueFalseScorer) + injecting.scorer_type = "true_false" + injecting.uses_objective = True + injecting_entry = MagicMock() + injecting_entry.name = "refusal" + injecting_entry.instance = injecting + injecting_entry.tags = {} + + non_injecting = MagicMock(spec=TrueFalseScorer) + non_injecting.scorer_type = "true_false" + non_injecting.uses_objective = False + non_injecting_entry = MagicMock() + non_injecting_entry.name = "substring" + non_injecting_entry.instance = non_injecting + non_injecting_entry.tags = {} + + mock_registry.get_all_instances.return_value = [injecting_entry, non_injecting_entry] + + result = await scoring_service.list_scorers_async() + by_name = {item.scorer_registry_name: item for item in result.items} + assert by_name["refusal"].uses_objective is True + assert by_name["substring"].uses_objective is False + # --------------------------------------------------------------------------- # # score_conversation_async @@ -194,9 +219,7 @@ async def test_raises_when_conversation_not_in_attack(self, scoring_service, moc request=ScoreConversationRequest(scorer_registry_name="x"), ) - async def test_raises_when_scorer_missing( - self, scoring_service, mock_memory, mock_registry - ) -> None: + async def test_raises_when_scorer_missing(self, scoring_service, mock_memory, mock_registry) -> None: mock_memory.get_attack_results.return_value = [_make_attack_result()] mock_registry.get.return_value = None @@ -207,9 +230,7 @@ async def test_raises_when_scorer_missing( request=ScoreConversationRequest(scorer_registry_name="missing-scorer"), ) - async def test_raises_when_conversation_empty( - self, scoring_service, mock_memory, mock_registry - ) -> None: + async def test_raises_when_conversation_empty(self, scoring_service, mock_memory, mock_registry) -> None: mock_memory.get_attack_results.return_value = [_make_attack_result()] mock_memory.get_conversation.return_value = [] mock_registry.get.return_value = MagicMock(spec=TrueFalseScorer) @@ -262,9 +283,7 @@ async def test_last_message_scores_most_recent_assistant_turn( assert len(result.scores) == 1 assert result.scores[0].score_value == "true" - async def test_whole_conversation_wraps_scorer( - self, scoring_service, mock_memory, mock_registry - ) -> None: + async def test_whole_conversation_wraps_scorer(self, scoring_service, mock_memory, mock_registry) -> None: mock_memory.get_attack_results.return_value = [_make_attack_result()] # Whole-conv mode just hands the last message to the wrapped scorer; content doesn't matter. last = _make_message([_make_piece(role="assistant")]) @@ -273,9 +292,7 @@ async def test_whole_conversation_wraps_scorer( scorer = MagicMock(spec=FloatScaleScorer) mock_registry.get.return_value = scorer - with patch( - "pyrit.score.conversation_scorer.create_conversation_scorer" - ) as mock_create: + with patch("pyrit.score.conversation_scorer.create_conversation_scorer") as mock_create: wrapped = MagicMock() wrapped.score_async = AsyncMock(return_value=[_make_pyrit_score()]) mock_create.return_value = wrapped @@ -283,9 +300,7 @@ async def test_whole_conversation_wraps_scorer( await scoring_service.score_conversation_async( attack_result_id="ar-1", conversation_id="conv-1", - request=ScoreConversationRequest( - scorer_registry_name="my-scorer", mode="whole_conversation" - ), + request=ScoreConversationRequest(scorer_registry_name="my-scorer", mode="whole_conversation"), ) mock_create.assert_called_once_with(scorer=scorer) @@ -302,9 +317,7 @@ async def test_whole_conversation_rejects_unsupported_scorer( await scoring_service.score_conversation_async( attack_result_id="ar-1", conversation_id="conv-1", - request=ScoreConversationRequest( - scorer_registry_name="my-scorer", mode="whole_conversation" - ), + request=ScoreConversationRequest(scorer_registry_name="my-scorer", mode="whole_conversation"), ) @@ -338,13 +351,9 @@ async def test_scores_specific_piece(self, scoring_service, mock_memory, mock_re assert scorer.score_async.await_args.kwargs["message"] is target_msg assert len(result.scores) == 1 - async def test_raises_when_piece_not_in_conversation( - self, scoring_service, mock_memory, mock_registry - ) -> None: + async def test_raises_when_piece_not_in_conversation(self, scoring_service, mock_memory, mock_registry) -> None: mock_memory.get_attack_results.return_value = [_make_attack_result()] - mock_memory.get_conversation.return_value = [ - _make_message([_make_piece(role="assistant", piece_id="other")]) - ] + mock_memory.get_conversation.return_value = [_make_message([_make_piece(role="assistant", piece_id="other")])] mock_registry.get.return_value = MagicMock(spec=TrueFalseScorer) with pytest.raises(LookupError, match="not part of conversation"): From 0f41bc3ab760af94b7de132d398212aeedb04ed2 Mon Sep 17 00:00:00 2001 From: jbolor21 <86250273+jbolor21@users.noreply.github.com> Date: Fri, 12 Jun 2026 10:34:19 -0700 Subject: [PATCH 3/3] adding custom scorer option POC --- frontend/src/components/Chat/ChatWindow.tsx | 4 +- .../src/components/Chat/ConversationPanel.tsx | 4 +- .../Chat/CustomScorerDialog.test.tsx | 299 ++++++++++ .../components/Chat/CustomScorerDialog.tsx | 511 ++++++++++++++++++ frontend/src/components/Chat/MessageList.tsx | 4 +- .../src/components/Chat/ScoreDialog.test.tsx | 189 +++++++ frontend/src/components/Chat/ScoreDialog.tsx | 189 ++++++- frontend/src/services/api.ts | 23 + frontend/src/types/index.ts | 52 ++ pyrit/backend/models/scoring.py | 125 ++++- pyrit/backend/routes/scoring.py | 97 +++- pyrit/backend/services/scoring_service.py | 279 ++++++++++ tests/unit/backend/test_scoring_service.py | 373 +++++++++++++ 13 files changed, 2129 insertions(+), 20 deletions(-) create mode 100644 frontend/src/components/Chat/CustomScorerDialog.test.tsx create mode 100644 frontend/src/components/Chat/CustomScorerDialog.tsx diff --git a/frontend/src/components/Chat/ChatWindow.tsx b/frontend/src/components/Chat/ChatWindow.tsx index c8205295e7..d2cbf57af2 100644 --- a/frontend/src/components/Chat/ChatWindow.tsx +++ b/frontend/src/components/Chat/ChatWindow.tsx @@ -4,7 +4,7 @@ import { Text, Tooltip, } from '@fluentui/react-components' -import { AddRegular, PanelRightRegular, ClipboardTaskRegular } from '@fluentui/react-icons' +import { AddRegular, PanelRightRegular, DataBarVerticalRegular } from '@fluentui/react-icons' import MessageList from './MessageList' import ChatInputArea from './ChatInputArea' import ConversationPanel from './ConversationPanel' @@ -606,7 +606,7 @@ export default function ChatWindow({ > + + + + + + ) +} + +// --------------------------------------------------------------------- // +// Per-kind subforms +// --------------------------------------------------------------------- // + +function FloatScaleFields({ + config, + onChange, +}: { + config: GeneralFloatScaleConfig + onChange: (c: GeneralFloatScaleConfig) => void +}) { + const rangeInvalid = config.max_value <= config.min_value + return ( + <> + +