Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions pyrit/score/conversation_scorer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import uuid
from abc import ABC, abstractmethod
from typing import cast
from uuid import UUID
from typing import TYPE_CHECKING, cast

from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score
from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer
from pyrit.score.scorer import Scorer
from pyrit.score.scorer_prompt_validator import ScorerPromptValidator
from pyrit.score.true_false.true_false_scorer import TrueFalseScorer

if TYPE_CHECKING:
from uuid import UUID


class ConversationScorer(Scorer, ABC):
"""
Expand Down Expand Up @@ -44,6 +45,11 @@ async def _score_async(self, message: Message, *, objective: str | None = None)
conversation, even when the triggering turn was blocked or errored; the wrapped
scorer's fallback only fires when the rendered conversation is genuinely unscoreable.

The wrapped scorer is invoked via its protected ``_score_async`` so it does not
persist its own copy of the scores. The outer ``Scorer.score_async`` that invoked
this method persists the returned scores exactly once, keyed to the original
``message_piece_id``.

Args:
message (Message): A message from the conversation to be scored.
The conversation ID from the first message piece is used to retrieve the full conversation from memory.
Expand Down Expand Up @@ -118,14 +124,10 @@ async def _score_async(self, message: Message, *, objective: str | None = None)
)

wrapped_scorer = self._get_wrapped_scorer()
scores = await wrapped_scorer.score_async(message=conversation_message, objective=objective)

# Generate new IDs for the scores to avoid ID collisions when the wrapped scorer's
# scores are already in the database
for score in scores:
score.id = uuid.uuid4()

return scores
# Call the wrapped scorer's protected ``_score_async`` rather than the public
# ``score_async`` so the wrapped scorer does not persist its own copy of the
# scores.
return await wrapped_scorer._score_async(message=conversation_message, objective=objective)

async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]:
"""
Expand Down
57 changes: 33 additions & 24 deletions tests/unit/score/test_conversation_history_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def test_conversation_history_scorer_score_async_success(patch_central_dat
objective="test_objective",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)
Expand All @@ -150,8 +150,8 @@ async def test_conversation_history_scorer_score_async_success(patch_central_dat
assert result_score.score_rationale == "Valid rationale"

# Verify the underlying scorer was called with conversation history
mock_scorer.score_async.assert_awaited_once()
call_args = mock_scorer.score_async.call_args
mock_scorer._score_async.assert_awaited_once()
call_args = mock_scorer._score_async.call_args
called_message = call_args.kwargs["message"]
called_piece = called_message.message_pieces[0]

Expand Down Expand Up @@ -227,13 +227,13 @@ async def test_conversation_history_scorer_filters_roles_correctly(patch_central
objective="test",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)
await scorer.score_async(message)

call_args = mock_scorer.score_async.call_args
call_args = mock_scorer._score_async.call_args
called_message = call_args.kwargs["message"]
called_piece = called_message.message_pieces[0]

Expand Down Expand Up @@ -274,14 +274,14 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data
objective="test",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)

await scorer.score_async(message)

call_args = mock_scorer.score_async.call_args
call_args = mock_scorer._score_async.call_args
called_message = call_args.kwargs["message"]
called_piece = called_message.message_pieces[0]

Expand All @@ -292,8 +292,13 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data
assert called_piece.attack_identifier == message_piece.attack_identifier


async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(patch_central_database):
"""Test that ConversationScorer regenerates score IDs to prevent database UNIQUE constraint violations."""
async def test_conversation_scorer_persists_scores_exactly_once(patch_central_database):
"""ConversationScorer must not double-persist: one inner score → one ScoreEntry in memory.

Regression guard for the bug where ConversationScorer called the wrapped scorer's
public ``score_async`` (which persists) and then the outer ``Scorer.score_async`` also
persisted, producing two identical ``ScoreEntry`` rows per call.
"""
memory = CentralMemory.get_memory_instance()
conversation_id = str(uuid.uuid4())

Expand All @@ -305,7 +310,6 @@ async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(p
)
memory.add_message_pieces_to_memory(message_pieces=[message_piece])

# Create a score and capture its original ID
score = Score(
score_value="0.5",
score_value_description="Test",
Expand All @@ -319,22 +323,27 @@ async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(p
)
original_id = score.id

# Mock scorer returns the score (which will be mutated by ConversationScorer)
# Mock only the protected _score_async. The public score_async (which persists) is intentionally
# left unconfigured, so the assertions below fail if ConversationScorer ever calls it instead.
mock_scorer = MagicMock(spec=SelfAskGeneralFloatScaleScorer)
mock_scorer._validator = ScorerPromptValidator(supported_data_types=["text"])
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

# Create conversation scorer and score the message
conv_scorer = create_conversation_scorer(scorer=mock_scorer)
message = MagicMock()
message.message_pieces = [message_piece]
result_scores = await conv_scorer.score_async(message)

# Verify that ConversationScorer regenerated the ID
assert len(result_scores) == 1
assert result_scores[0].id != original_id, "ConversationScorer should regenerate score IDs to prevent collisions"
assert isinstance(result_scores[0].id, uuid.UUID), "Regenerated ID should be a valid UUID"
assert result_scores[0].id == original_id, (
"ConversationScorer should preserve the inner scorer's score ID; only the outer "
"Scorer.score_async should persist, so no ID regeneration is needed."
)

persisted = list(memory.get_scores(score_type="float_scale"))
assert len(persisted) == 1, f"Expected exactly one ScoreEntry persisted; got {len(persisted)}"
assert persisted[0].id == original_id


def test_conversation_scorer_cannot_be_instantiated_directly():
Expand Down Expand Up @@ -539,7 +548,7 @@ async def test_conversation_scorer_uses_partial_content_when_score_blocked_conte
objective="test",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)
Expand All @@ -549,8 +558,8 @@ async def test_conversation_scorer_uses_partial_content_when_score_blocked_conte
assert len(scores) == 1

# Verify the underlying scorer was called with partial content, not error JSON
mock_scorer.score_async.assert_awaited_once()
call_args = mock_scorer.score_async.call_args
mock_scorer._score_async.assert_awaited_once()
call_args = mock_scorer._score_async.call_args
called_message = call_args.kwargs["message"]
called_piece = called_message.message_pieces[0]

Expand Down Expand Up @@ -611,7 +620,7 @@ async def test_conversation_scorer_uses_error_json_when_score_blocked_content_di
objective="test",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)
Expand All @@ -621,8 +630,8 @@ async def test_conversation_scorer_uses_error_json_when_score_blocked_content_di
assert len(scores) == 1

# Verify the underlying scorer was called with error JSON, not partial content
mock_scorer.score_async.assert_awaited_once()
call_args = mock_scorer.score_async.call_args
mock_scorer._score_async.assert_awaited_once()
call_args = mock_scorer._score_async.call_args
called_message = call_args.kwargs["message"]
called_piece = called_message.message_pieces[0]

Expand Down Expand Up @@ -678,7 +687,7 @@ async def test_conversation_scorer_blocked_input_message_does_not_raise(patch_ce
objective="test",
score_type="float_scale",
)
mock_scorer.score_async = AsyncMock(return_value=[score])
mock_scorer._score_async = AsyncMock(return_value=[score])
mock_scorer.validate_return_scores = MagicMock()

scorer = create_conversation_scorer(scorer=mock_scorer)
Expand All @@ -687,7 +696,7 @@ async def test_conversation_scorer_blocked_input_message_does_not_raise(patch_ce
scores = await scorer.score_async(blocked_message)

assert len(scores) == 1
mock_scorer.score_async.assert_awaited_once()
mock_scorer._score_async.assert_awaited_once()


async def test_conversation_scorer_blocked_trigger_preserves_prior_turn_scoring(patch_central_database):
Expand Down
Loading