From 72f02b62c6183f09200b7f7e641373f10c862128 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sat, 6 Jun 2026 20:38:39 -0700 Subject: [PATCH 01/12] Introduce Conversation model; move target identifier off MessagePiece Move prompt_target_identifier to a new Conversations table (hydrated on read) and remove attack_identifier from MessagePiece (it now lives only on AttackResult.atomic_attack_identifier). The deprecated attack_id query filter resolves via get_attack_results() to the attack's main conversation. Adds the Conversation model and an alembic migration for the Conversations table, and updates all call sites and unit tests accordingly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../attack/component/conversation_manager.py | 8 - .../attack/multi_turn/chunked_request.py | 4 - pyrit/executor/attack/multi_turn/crescendo.py | 8 - .../attack/multi_turn/multi_prompt_sending.py | 4 - .../executor/attack/multi_turn/red_teaming.py | 9 +- .../attack/multi_turn/tree_of_attacks.py | 9 - .../attack/single_turn/context_compliance.py | 3 - .../attack/single_turn/prompt_sending.py | 4 - pyrit/executor/attack/streaming/barge_in.py | 2 - pyrit/executor/benchmark/fairness_bias.py | 2 - pyrit/executor/promptgen/anecdoctor.py | 4 - pyrit/executor/promptgen/fuzzer/fuzzer.py | 1 - .../promptgen/fuzzer/fuzzer_converter_base.py | 1 - .../fuzzer/fuzzer_crossover_converter.py | 1 - .../fuzzer/fuzzer_expand_converter.py | 1 - pyrit/executor/workflow/xpia.py | 3 - .../b2f4c6a8d1e3_add_conversations_table.py | 142 ++++++++++++ pyrit/memory/azure_sql_memory.py | 11 +- pyrit/memory/memory_interface.py | 178 +++++++++++++-- pyrit/memory/memory_models.py | 79 +++++-- pyrit/memory/sqlite_memory.py | 14 +- pyrit/models/__init__.py | 2 + pyrit/models/messages/__init__.py | 2 + pyrit/models/messages/conversation.py | 31 +++ pyrit/models/messages/conversations.py | 1 - pyrit/models/messages/message_piece.py | 4 +- .../llm_generic_text_converter.py | 1 - pyrit/prompt_normalizer/prompt_normalizer.py | 24 +- pyrit/prompt_target/common/prompt_target.py | 9 +- .../_openai_realtime_streaming_session.py | 6 - .../openai/openai_realtime_target.py | 12 +- .../openai/openai_response_target.py | 2 - .../playwright_copilot_target.py | 1 - pyrit/score/conversation_scorer.py | 1 - pyrit/score/float_scale/float_scale_scorer.py | 4 +- .../score/float_scale/insecure_code_scorer.py | 1 - .../self_ask_general_float_scale_scorer.py | 1 - .../float_scale/self_ask_likert_scorer.py | 1 - .../float_scale/self_ask_scale_scorer.py | 1 - pyrit/score/scorer.py | 5 - .../true_false/self_ask_category_scorer.py | 1 - .../self_ask_general_true_false_scorer.py | 1 - .../self_ask_question_answer_scorer.py | 1 - .../true_false/self_ask_refusal_scorer.py | 1 - .../true_false/self_ask_true_false_scorer.py | 1 - .../component/test_conversation_manager.py | 118 ++++------ .../single_turn/test_context_compliance.py | 3 - .../attack/single_turn/test_prompt_sending.py | 1 - .../attack/single_turn/test_skeleton_key.py | 1 - .../attack/streaming/test_barge_in.py | 1 - .../promptgen/fuzzer/test_fuzzer_converter.py | 1 - .../memory_interface/test_batching_scale.py | 1 - .../memory_interface/test_interface_export.py | 2 +- .../test_interface_prompts.py | 211 ++++-------------- .../memory_interface/test_interface_scores.py | 27 ++- tests/unit/memory/test_azure_sql_memory.py | 76 +++---- tests/unit/memory/test_memory_models.py | 18 -- tests/unit/memory/test_sqlite_memory.py | 67 +++--- tests/unit/mocks.py | 6 - tests/unit/models/test_attack_result.py | 1 - tests/unit/models/test_message_piece.py | 29 +-- .../test_persuasion_converter.py | 1 - .../test_variation_converter.py | 1 - .../test_prompt_normalizer.py | 3 - .../test_normalize_async_integration.py | 1 - .../target/test_openai_chat_target.py | 6 - .../test_openai_realtime_streaming_session.py | 72 +----- .../target/test_openai_response_target.py | 6 - .../target/test_prompt_target.py | 10 - .../score/test_conversation_history_scorer.py | 2 - tests/unit/score/test_scorer.py | 65 ------ 71 files changed, 628 insertions(+), 704 deletions(-) create mode 100644 pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py create mode 100644 pyrit/models/messages/conversation.py diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index b4b47175f2..86127c765a 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -54,7 +54,6 @@ def get_adversarial_chat_messages( prepended_conversation: list[Message], *, adversarial_chat_conversation_id: str, - attack_identifier: ComponentIdentifier, adversarial_chat_target_identifier: ComponentIdentifier, labels: dict[str, str] | None = None, # deprecated ) -> list[Message]: @@ -72,7 +71,6 @@ def get_adversarial_chat_messages( Args: prepended_conversation: The original conversation messages to transform. adversarial_chat_conversation_id: Conversation ID for the adversarial chat. - attack_identifier (ComponentIdentifier): Attack identifier to associate with messages. adversarial_chat_target_identifier (ComponentIdentifier): Target identifier for the adversarial chat. labels: Optional labels to associate with the messages. Deprecated: This parameter will be removed in a release 0.16.0. @@ -114,7 +112,6 @@ def get_adversarial_chat_messages( original_value_data_type=piece.original_value_data_type, converted_value_data_type=piece.converted_value_data_type, conversation_id=adversarial_chat_conversation_id, - attack_identifier=attack_identifier, prompt_target_identifier=adversarial_chat_target_identifier, labels=labels or {}, # deprecated ) @@ -190,20 +187,17 @@ class ConversationManager: def __init__( self, *, - attack_identifier: ComponentIdentifier, prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the conversation manager. Args: - attack_identifier (ComponentIdentifier): The identifier of the attack this manager belongs to. prompt_normalizer: Optional prompt normalizer for converting prompts. If not provided, a default PromptNormalizer instance will be created. """ self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._memory = CentralMemory.get_memory_instance() - self._attack_identifier = attack_identifier def get_conversation(self, conversation_id: str) -> list[Message]: """ @@ -276,7 +270,6 @@ def set_system_prompt( target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=self._attack_identifier, labels=labels, # deprecated ) @@ -485,7 +478,6 @@ async def add_prepended_conversation_to_memory_async( for piece in message_copy.message_pieces: piece.conversation_id = conversation_id - piece.attack_identifier = self._attack_identifier # Count turns at message level (only assistant/simulated_assistant messages) # A multi-part response still counts as one turn diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 4e7e5caefa..b7b04ec10e 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -172,7 +172,6 @@ def __init__( # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -279,7 +278,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -291,7 +289,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) # Store the response @@ -377,7 +374,6 @@ async def _score_combined_value_async( with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective=objective, ): diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index bc987f1270..8a2e6dd730 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -232,7 +232,6 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -331,7 +330,6 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: self._adversarial_chat.set_system_prompt( system_prompt=system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -545,7 +543,6 @@ async def _send_prompt_to_adversarial_chat_async( with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -554,7 +551,6 @@ async def _send_prompt_to_adversarial_chat_async( message=message, conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -649,7 +645,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -660,7 +655,6 @@ async def _send_prompt_to_objective_target_async( conversation_id=context.session.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -689,7 +683,6 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective: with execution_context( component_role=ComponentRole.REFUSAL_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._refusal_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -721,7 +714,6 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index e15ef6c63d..cc4c53531d 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -175,7 +175,6 @@ def __init__( # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -355,7 +354,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -367,7 +365,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, # combined with strategy labels at _setup() - attack_identifier=self.get_identifier(), ) async def _evaluate_response_async(self, *, response: Message, objective: str) -> Score | None: @@ -389,7 +386,6 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) - with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, objective=objective, ): diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 8c0d34c6eb..def0bd0113 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -166,7 +166,7 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() - self._conversation_manager = ConversationManager(attack_identifier=self.get_identifier()) + self._conversation_manager = ConversationManager() # set the maximum number of turns for the attack if max_turns <= 0: @@ -260,7 +260,6 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: self._adversarial_chat.set_system_prompt( system_prompt=adversarial_system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -270,7 +269,6 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: adversarial_messages = get_adversarial_chat_messages( prepended_conversation=context.prepended_conversation, adversarial_chat_conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), adversarial_chat_target_identifier=self._adversarial_chat.get_identifier(), labels=context.memory_labels, ) @@ -388,7 +386,6 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -397,7 +394,6 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] message=prompt_message, conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -550,7 +546,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -563,7 +558,6 @@ async def _send_prompt_to_objective_target_async( response_converter_configurations=self._response_converters, target=self._objective_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) if response is None: @@ -598,7 +592,6 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 0cd557b1c6..0e40119c98 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -404,7 +404,6 @@ async def initialize_with_prepended_conversation_async( # Use ConversationManager to add messages to memory conversation_manager = ConversationManager( - attack_identifier=self._attack_id, prompt_normalizer=self._prompt_normalizer, ) @@ -558,7 +557,6 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -570,7 +568,6 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: conversation_id=self.objective_target_conversation_id, target=self._objective_target, labels=self._memory_labels, - attack_identifier=self._attack_id, ) # Store the last response text for reference @@ -618,7 +615,6 @@ async def _send_initial_prompt_to_target_async(self) -> Message: with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -630,7 +626,6 @@ async def _send_initial_prompt_to_target_async(self) -> Message: conversation_id=self.objective_target_conversation_id, target=self._objective_target, labels=self._memory_labels, - attack_identifier=self._attack_id, ) # Store the last response text for reference @@ -675,7 +670,6 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=objective, @@ -1021,7 +1015,6 @@ async def _generate_first_turn_prompt_async(self, objective: str) -> str: self._adversarial_chat.set_system_prompt( system_prompt=system_prompt, conversation_id=self.adversarial_chat_conversation_id, - attack_identifier=self._attack_id, labels=self._memory_labels, # deprecated ) @@ -1138,7 +1131,6 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -1148,7 +1140,6 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: conversation_id=self.adversarial_chat_conversation_id, target=self._adversarial_chat, labels=self._memory_labels, - attack_identifier=self._attack_id, ) return response.get_value() diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index 4568a158e8..98b0292541 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -234,7 +234,6 @@ async def _get_objective_as_benign_question_async( response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -261,7 +260,6 @@ async def _get_benign_question_answer_async( response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -286,7 +284,6 @@ async def _get_objective_as_question_async(self, *, objective: str, context: Sin response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 794cc4294e..32e4db677b 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -106,7 +106,6 @@ def __init__( # Skip criteria could be set directly in the injected prompt normalizer self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -314,7 +313,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.conversation_id, objective=context.params.objective, @@ -326,7 +324,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, # combined with strategy labels at _setup() - attack_identifier=self.get_identifier(), ) async def _evaluate_response_async( @@ -353,7 +350,6 @@ async def _evaluate_response_async( with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, objective=objective, ): diff --git a/pyrit/executor/attack/streaming/barge_in.py b/pyrit/executor/attack/streaming/barge_in.py index 161e364f85..2c96fbdb62 100644 --- a/pyrit/executor/attack/streaming/barge_in.py +++ b/pyrit/executor/attack/streaming/barge_in.py @@ -99,7 +99,6 @@ def __init__( self._response_converters = attack_converter_config.response_converters self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -163,7 +162,6 @@ async def _perform_async(self, *, context: BargeInAttackContext[Any]) -> AttackR request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, prepended_conversation=context.prepended_conversation, - attack_identifier=self.get_identifier(), persist_prepended_conversation=False, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 4e3bfaa505..4ab0fe432a 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -21,7 +21,6 @@ from pyrit.models import ( AttackOutcome, AttackResult, - ComponentIdentifier, Message, build_atomic_attack_identifier, ) @@ -198,7 +197,6 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta objective=context.generated_objective, outcome=AttackOutcome.FAILURE, atomic_attack_identifier=build_atomic_attack_identifier( - attack_identifier=ComponentIdentifier.of(self), ), labels=context.memory_labels, ) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 0d953fe8ad..abc8efbd69 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -218,7 +218,6 @@ async def _setup_async(self, *, context: AnecdoctorContext) -> None: self._objective_target.set_system_prompt( system_prompt=system_prompt, conversation_id=context.conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -312,7 +311,6 @@ async def _send_examples_to_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) def _load_prompt_from_yaml(self, *, yaml_filename: str) -> str: @@ -381,7 +379,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> self._processing_model.set_system_prompt( system_prompt=kg_system_prompt, conversation_id=kg_conversation_id, - attack_identifier=self.get_identifier(), labels=self._memory_labels, # deprecated ) @@ -399,7 +396,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=self._memory_labels, - attack_identifier=self.get_identifier(), ) if not kg_response: diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 73410799a4..492ae2ee62 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -1003,7 +1003,6 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts requests=requests, target=self._objective_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), batch_size=self._batch_size, ) diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py index 5b211487cc..dd9519bdf4 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py @@ -84,7 +84,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 812979a797..2f7ddf5be4 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -82,7 +82,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index 627ed159ed..e8fae2fa18 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -58,7 +58,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index c367cd9b02..80c0fdefd0 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -335,7 +335,6 @@ async def _setup_attack_async(self, *, context: XPIAContext) -> str: response_converter_configurations=self._response_converters, target=self._attack_setup_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), conversation_id=context.attack_setup_target_conversation_id, ) @@ -374,7 +373,6 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: original_value=processing_response, original_value_data_type="text", role="assistant", - attack_identifier=self.get_identifier(), ) ], ) @@ -576,7 +574,6 @@ async def process_async() -> str: request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), conversation_id=context.processing_conversation_id, ) diff --git a/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py b/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py new file mode 100644 index 0000000000..8b8d93f8ca --- /dev/null +++ b/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Introduce the Conversations table for conversation-scoped metadata and stop +stamping that metadata onto every PromptMemoryEntry row. + +Creates ``Conversations`` (one row per ``conversation_id``) holding the target +identifier, backfills it from the existing +``PromptMemoryEntries.prompt_target_identifier`` column (plus placeholder rows for +conversation_ids referenced only by ``AttackResultEntries``), and drops the now +per-row ``prompt_target_identifier`` and ``attack_identifier`` columns from +``PromptMemoryEntries``. + +Revision ID: b2f4c6a8d1e3 +Revises: 9c8b7a6d5e4f +Create Date: 2026-05-20 12:00:00.000000 +""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence # noqa: TC003 + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b2f4c6a8d1e3" +down_revision: str | None = "9c8b7a6d5e4f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +logger = logging.getLogger(__name__) + + +def upgrade() -> None: + """Apply this schema upgrade.""" + op.create_table( + "Conversations", + sa.Column("conversation_id", sa.String(), primary_key=True, nullable=False), + sa.Column("target_identifier", sa.JSON(), nullable=True), + sa.Column("pyrit_version", sa.String(), nullable=True), + ) + + _backfill_conversations() + + # Stop persisting conversation-scoped metadata per row: the target identifier now + # lives in Conversations, and the attack identifier is no longer stamped on pieces + # (resolved via AttackResult). Batch op for SQLite portability. + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.drop_column("prompt_target_identifier") + batch_op.drop_column("attack_identifier") + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # Re-add the dropped columns (data is not restored) then drop Conversations. + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.add_column(sa.Column("prompt_target_identifier", sa.JSON(), nullable=True)) + batch_op.add_column(sa.Column("attack_identifier", sa.JSON(), nullable=True)) + op.drop_table("Conversations") + + +def _backfill_conversations() -> None: + """ + Populate ``Conversations`` with one row per distinct ``conversation_id``. + + The target identifier is taken from the existing + ``PromptMemoryEntries.prompt_target_identifier`` column, preferring a non-null + value when a conversation has rows with differing targets (a non-null target + always wins over null; a WARNING is logged if two distinct non-null targets are + seen for the same conversation). Conversation ids that are referenced only by + ``AttackResultEntries`` (no prompt rows) get a placeholder row with a null + target so reads/joins stay consistent. + + Idempotent: only conversation_ids not already present in ``Conversations`` are + inserted. + """ + bind = op.get_bind() + + existing_ids = { + row[0] for row in bind.execute(sa.text('SELECT conversation_id FROM "Conversations"')).fetchall() + } + + targets_by_conversation: dict[str, str | None] = {} + conflict_warnings = 0 + + prompt_rows = bind.execute( + sa.text( + 'SELECT conversation_id, prompt_target_identifier ' + 'FROM "PromptMemoryEntries" ' + "WHERE conversation_id IS NOT NULL " + "ORDER BY sequence" + ) + ).fetchall() + + for conversation_id, target_identifier in prompt_rows: + if conversation_id is None: + continue + current = targets_by_conversation.get(conversation_id, "__unset__") + if current == "__unset__": + targets_by_conversation[conversation_id] = target_identifier + elif target_identifier is not None: + if current is None: + targets_by_conversation[conversation_id] = target_identifier + elif current != target_identifier: + conflict_warnings += 1 + logger.warning( + f"Backfill: conversation_id {conversation_id!r} has multiple distinct " + f"target identifiers; keeping the first non-null value." + ) + + # Conversation ids referenced only by AttackResultEntries (no prompt rows). + attack_rows = bind.execute( + sa.text('SELECT DISTINCT conversation_id FROM "AttackResultEntries" WHERE conversation_id IS NOT NULL') + ).fetchall() + for (conversation_id,) in attack_rows: + if conversation_id is not None and conversation_id not in targets_by_conversation: + targets_by_conversation[conversation_id] = None + + insert_stmt = sa.text( + 'INSERT INTO "Conversations" (conversation_id, target_identifier, pyrit_version) ' + "VALUES (:cid, :target, :version)" + ) + + inserted = 0 + for conversation_id, target_identifier in targets_by_conversation.items(): + if conversation_id in existing_ids: + continue + bind.execute( + insert_stmt, + {"cid": conversation_id, "target": target_identifier, "version": None}, + ) + inserted += 1 + + if inserted or conflict_warnings: + logger.info( + f"Conversations backfill: inserted {inserted} row(s); " + f"{conflict_warnings} target-conflict warning(s)." + ) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 81b97d8716..2dfc3b183d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -718,6 +718,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return + self._capture_conversations(message_pieces=pieces_to_insert) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def dispose_engine(self) -> None: @@ -827,12 +828,20 @@ def _query_entries( try: query = session.query(model_class) if join_scores and model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.scores)) + query = query.options( + joinedload(PromptMemoryEntry.scores), + joinedload(PromptMemoryEntry.conversation_metadata), + ) elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), + joinedload(AttackResultEntry.last_response).joinedload( + PromptMemoryEntry.conversation_metadata + ), joinedload(AttackResultEntry.last_score), ) + elif model_class == PromptMemoryEntry: + query = query.options(joinedload(PromptMemoryEntry.conversation_metadata)) if conditions is not None: query = query.filter(conditions) if order_by is not None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index f0e5fdeb9b..e062185a2d 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar -from sqlalchemy import MetaData, and_, not_, or_ +from sqlalchemy import MetaData, and_, not_, or_, select from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -27,6 +27,7 @@ from pyrit.memory.memory_models import ( AttackResultEntry, Base, + ConversationEntry, EmbeddingDataEntry, PromptMemoryEntry, ScenarioResultEntry, @@ -35,6 +36,8 @@ ) from pyrit.models import ( AttackResult, + ComponentIdentifier, + Conversation, ConversationStats, DataTypeSerializer, IdentifierFilter, @@ -349,6 +352,66 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] Insert a list of message pieces into the memory storage. """ + def _capture_conversations(self, *, message_pieces: Sequence[MessagePiece]) -> None: + """ + Record one ``Conversations`` row per conversation for the given pieces. + + Conversation-scoped metadata (currently the target identifier) is persisted + once per ``conversation_id`` instead of being stamped onto every piece. This + runs from each backend's ``add_message_pieces_to_memory`` so every write path + -- normalizer, conversation duplication, prepended conversations, direct + target writers -- captures the target through a single choke point. + + Args: + message_pieces (Sequence[MessagePiece]): The pieces being persisted. + """ + targets_by_conversation: dict[str, ComponentIdentifier | None] = {} + for piece in message_pieces: + if piece.not_in_memory: + continue + conversation_id = piece.conversation_id + if targets_by_conversation.get(conversation_id) is None: + targets_by_conversation[conversation_id] = piece.prompt_target_identifier + for conversation_id, target_identifier in targets_by_conversation.items(): + self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) + + def _upsert_conversation( + self, *, conversation_id: str, target_identifier: ComponentIdentifier | None + ) -> None: + """ + Insert or update the ``Conversations`` row for ``conversation_id``. + + A non-``None`` ``target_identifier`` is written; a ``None`` value never + overwrites a target already recorded for the conversation (so response/copy + pieces and write ordering cannot clobber it). + + Args: + conversation_id (str): The conversation to record. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. + + Raises: + SQLAlchemyError: If the upsert fails. + """ + if not conversation_id: + return + entry = ConversationEntry( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) + with closing(self.get_session()) as session: + try: + existing = session.get(ConversationEntry, conversation_id) + if existing is None: + session.add(entry) + elif target_identifier is not None: + existing.target_identifier = entry.target_identifier + existing.pyrit_version = entry.pyrit_version + session.commit() + except SQLAlchemyError as e: + session.rollback() + logger.exception(f"Error upserting conversation {conversation_id}: {e}") + raise + @abc.abstractmethod def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ @@ -853,6 +916,25 @@ def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: message_pieces = self.get_message_pieces(conversation_id=conversation_id) return group_conversation_message_pieces_by_sequence(message_pieces=message_pieces) + def get_conversation_metadata(self, *, conversation_id: str) -> Conversation | None: + """ + Return the conversation-scoped metadata stored for ``conversation_id``. + + Args: + conversation_id (str): The conversation to look up. + + Returns: + Conversation | None: The conversation metadata (including the target + identifier), or ``None`` if no row exists for the conversation. + """ + entries = self._query_entries( + ConversationEntry, + conditions=ConversationEntry.conversation_id == str(conversation_id), + ) + if not entries: + return None + return entries[0].get_conversation() + def get_request_from_response(self, *, response: Message) -> Message: """ Retrieve the request that produced the given response. @@ -874,6 +956,80 @@ def get_request_from_response(self, *, response: Message) -> Message: conversation = self.get_conversation(conversation_id=response.conversation_id) return conversation[response.sequence - 1] + def _resolve_attack_id_to_conversation_condition(self, *, attack_id: str | uuid.UUID) -> Any: + """ + Build a deprecated ``attack_id`` filter condition for ``get_message_pieces``. + + The attack identifier is no longer stamped on every piece. Instead, resolve the + raw attack-strategy hash against persisted ``AttackResult`` rows and constrain + the query to those attacks' main conversations. + + Args: + attack_id (str | uuid.UUID): The raw attack-strategy identifier hash. + + Returns: + Any: A SQLAlchemy condition restricting pieces to the matching attacks' + main conversation ids (matches nothing when no attack matches). + """ + print_deprecation_message( + old_item="get_message_pieces(attack_id=...) / get_prompt_scores(attack_id=...)", + new_item="get_message_pieces(conversation_id=...) resolved via get_attack_results(...)", + removed_in="0.17.0", + ) + matching_conversation_ids = { + result.conversation_id + for result in self.get_attack_results() + if (strategy := result.get_attack_strategy_identifier()) is not None and strategy.hash == str(attack_id) + } + return PromptMemoryEntry.conversation_id.in_(matching_conversation_ids) + + def _build_message_piece_identifier_conditions( + self, *, identifier_filters: Sequence[IdentifierFilter] + ) -> list[Any]: + """ + Build ``get_message_pieces`` conditions for identifier filters. + + ``CONVERTER`` identifiers remain on the piece. ``TARGET`` identifiers moved to + the ``Conversations`` table, so target filters are applied via a subquery on + ``ConversationEntry`` correlated by ``conversation_id``. ``ATTACK`` identifiers + are no longer stamped on pieces (use ``get_attack_results`` instead) and are + rejected by ``_build_identifier_filter_conditions``. + + Args: + identifier_filters (Sequence[IdentifierFilter]): The filters to convert. + + Returns: + list[Any]: SQLAlchemy conditions for the message-piece query. + """ + conditions: list[Any] = [] + piece_filters = [f for f in identifier_filters if f.identifier_type != IdentifierType.TARGET] + target_filters = [f for f in identifier_filters if f.identifier_type == IdentifierType.TARGET] + + if piece_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=piece_filters, + identifier_column_map={ + IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, + }, + caller="get_message_pieces", + ) + ) + if target_filters: + target_conditions = self._build_identifier_filter_conditions( + identifier_filters=target_filters, + identifier_column_map={ + IdentifierType.TARGET: ConversationEntry.target_identifier, + }, + caller="get_message_pieces", + ) + conditions.append( + PromptMemoryEntry.conversation_id.in_( + select(ConversationEntry.conversation_id).where(and_(*target_conditions)) + ) + ) + return conditions + def get_message_pieces( self, *, @@ -929,13 +1085,7 @@ def get_message_pieces( try: conditions: list[Any] = [] if attack_id: - conditions.append( - self._get_condition_json_property_match( - json_column=PromptMemoryEntry.attack_identifier, - property_path="$.hash", - value=str(attack_id), - ) - ) + conditions.append(self._resolve_attack_id_to_conversation_condition(attack_id=attack_id)) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -953,17 +1103,7 @@ def get_message_pieces( if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if identifier_filters: - conditions.extend( - self._build_identifier_filter_conditions( - identifier_filters=identifier_filters, - identifier_column_map={ - IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, - IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, - IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, - }, - caller="get_message_pieces", - ) - ) + conditions.extend(self._build_message_piece_identifier_conditions(identifier_filters=identifier_filters)) # Identify list parameters that may need batching list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 9f52b38afa..d31012cce9 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import ( DeclarativeBase, Mapped, + foreign, mapped_column, relationship, ) @@ -37,6 +38,7 @@ AttackResult, ChatMessageRole, ComponentIdentifier, + Conversation, ConversationReference, ConversationType, MessagePiece, @@ -239,7 +241,6 @@ class PromptMemoryEntry(Base): e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. - attack_identifier (dict[str, str]): The attack identifier for the prompt. original_value_data_type (PromptDataType): The data type of the original prompt (text, image) original_value (str): The text of the original prompt. If prompt is an image, it's a link. original_value_sha256 (str): The SHA256 hash of the original prompt data. @@ -267,8 +268,6 @@ class PromptMemoryEntry(Base): prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON) targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON) converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON) - prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) - attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) original_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) @@ -294,6 +293,18 @@ class PromptMemoryEntry(Base): foreign_keys="ScoreEntry.prompt_request_response_id", ) + # Conversation-scoped metadata (e.g. the target identifier) lives in the + # ``Conversations`` table keyed by ``conversation_id`` rather than on every row. + # ``viewonly`` because this join is read-only (there is no FK constraint); reads + # eager-load it via ``joinedload`` so detached entries can still hydrate the + # target onto the reconstructed ``MessagePiece``. + conversation_metadata: Mapped["ConversationEntry | None"] = relationship( + "ConversationEntry", + primaryjoin=lambda: foreign(PromptMemoryEntry.conversation_id) == ConversationEntry.conversation_id, + viewonly=True, + uselist=False, + ) + def __init__(self, *, entry: MessagePiece) -> None: """ Initialize a PromptMemoryEntry from a MessagePiece. @@ -310,8 +321,6 @@ def __init__(self, *, entry: MessagePiece) -> None: self.prompt_metadata = entry.prompt_metadata self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) - self.prompt_target_identifier = _dump_identifier(entry.prompt_target_identifier) or {} - self.attack_identifier = _dump_identifier(entry.attack_identifier) or {} self.original_value = entry.original_value self.original_value_data_type = entry.original_value_data_type @@ -336,8 +345,6 @@ def get_message_piece(self) -> MessagePiece: # Reconstruct ComponentIdentifiers with the stored pyrit_version stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION converter_ids = _load_identifiers(self.converter_identifiers, pyrit_version=stored_version) - target_id = _load_identifier(self.prompt_target_identifier, pyrit_version=stored_version) - attack_id = _load_identifier(self.attack_identifier, pyrit_version=stored_version) message_piece = MessagePiece( role=self.role, @@ -350,8 +357,6 @@ def get_message_piece(self) -> MessagePiece: sequence=self.sequence, prompt_metadata=self.prompt_metadata, converter_identifiers=converter_ids or [], - prompt_target_identifier=target_id, - attack_identifier=attack_id, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, response_error=self.response_error, @@ -365,6 +370,11 @@ def get_message_piece(self) -> MessagePiece: message_piece.labels = self.labels or {} message_piece.targeted_harm_categories = self.targeted_harm_categories or [] message_piece.scores = [score.get_score() for score in self.scores] + # The target identifier is conversation-scoped: hydrate it from the + # ``Conversations`` row (eager-loaded via ``conversation_metadata``) so it is + # served once per conversation rather than stored on every piece. + if self.conversation_metadata is not None: + message_piece.prompt_target_identifier = self.conversation_metadata.get_conversation().target_identifier return message_piece def __str__(self) -> str: @@ -374,13 +384,50 @@ def __str__(self) -> str: Returns: str: Formatted string representation of the memory entry. """ - if self.prompt_target_identifier: - # prompt_target_identifier is stored as dict in the database - class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get( - "__type__", "Unknown" - ) - return f"{class_name}: {self.role}: {self.converted_value}" - return f": {self.role}: {self.converted_value}" + return f"{self.role}: {self.converted_value}" + + +class ConversationEntry(Base): + """ + Conversation-scoped metadata, persisted once per ``conversation_id``. + + Holds identifiers that belong to the conversation as a whole -- currently the + target identifier -- so they are not duplicated onto every ``PromptMemoryEntry`` + row. The target is captured once when the conversation's pieces are written and + rehydrated onto pieces on read. + """ + + __tablename__ = "Conversations" + __table_args__ = {"extend_existing": True} + + conversation_id = mapped_column(String, primary_key=True, nullable=False) + target_identifier: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) + + # Version of PyRIT used when this entry was created. Nullable for backwards + # compatibility with existing databases. + pyrit_version = mapped_column(String, nullable=True) + + def __init__(self, *, conversation: Conversation) -> None: + """ + Initialize a ConversationEntry from a Conversation model. + + Args: + conversation (Conversation): The conversation metadata to persist. + """ + self.conversation_id = conversation.conversation_id + self.target_identifier = _dump_identifier(conversation.target_identifier) + self.pyrit_version = pyrit.__version__ + + def get_conversation(self) -> Conversation: + """ + Convert this database entry back into a Conversation model. + + Returns: + Conversation: The reconstructed conversation metadata. + """ + stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION + target_id = _load_identifier(self.target_identifier, pyrit_version=stored_version) + return Conversation(conversation_id=self.conversation_id, target_identifier=target_id) class EmbeddingDataEntry(Base): diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 5f628ab075..ce961f2fbb 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -312,6 +312,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return + self._capture_conversations(message_pieces=pieces_to_insert) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: @@ -361,12 +362,21 @@ def _query_entries( try: query = session.query(model_class) if join_scores and model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.scores)) + query = query.options( + joinedload(PromptMemoryEntry.scores), + joinedload(PromptMemoryEntry.conversation_metadata), + ) elif model_class == AttackResultEntry: query = query.options( - joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), + joinedload(AttackResultEntry.last_response) + .joinedload(PromptMemoryEntry.scores), + joinedload(AttackResultEntry.last_response).joinedload( + PromptMemoryEntry.conversation_metadata + ), joinedload(AttackResultEntry.last_score), ) + elif model_class == PromptMemoryEntry: + query = query.options(joinedload(PromptMemoryEntry.conversation_metadata)) if conditions is not None: query = query.filter(conditions) if order_by is not None: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 2a9fb9aec1..4323e447cb 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -70,6 +70,7 @@ SeedType, ) from pyrit.models.messages import ( + Conversation, Message, MessagePiece, construct_response_from_request, @@ -126,6 +127,7 @@ "ComponentIdentifier", "compute_eval_hash", "config_hash", + "Conversation", "ConversationReference", "ConversationStats", "ConversationType", diff --git a/pyrit/models/messages/__init__.py b/pyrit/models/messages/__init__.py index fca91f47ba..58c9f1a63e 100644 --- a/pyrit/models/messages/__init__.py +++ b/pyrit/models/messages/__init__.py @@ -9,6 +9,7 @@ - conversations: Free functions that operate on collections of messages/pieces. """ +from pyrit.models.messages.conversation import Conversation from pyrit.models.messages.conversations import ( construct_response_from_request, flatten_to_message_pieces, @@ -20,6 +21,7 @@ from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces __all__ = [ + "Conversation", "Message", "MessagePiece", "construct_response_from_request", diff --git a/pyrit/models/messages/conversation.py b/pyrit/models/messages/conversation.py new file mode 100644 index 0000000000..f5b8d956fb --- /dev/null +++ b/pyrit/models/messages/conversation.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + +from pyrit.models.score import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ComponentIdentifierField, +) + + +class Conversation(BaseModel): + """ + Conversation-scoped metadata shared by every piece in a conversation. + + A ``Conversation`` records identifiers that belong to the conversation as a + whole rather than to any individual ``MessagePiece`` -- most importantly the + target the conversation is held with. Persisting these once per conversation + (instead of stamping them onto every piece/row) is what keeps ``MessagePiece`` + small. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + conversation_id: str + target_identifier: ComponentIdentifierField | None = None diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index b225e527b2..32bbf0f0be 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -206,7 +206,6 @@ def construct_response_from_request( conversation_id=request.conversation_id, labels=request.labels, prompt_target_identifier=request.prompt_target_identifier, - attack_identifier=request.attack_identifier, original_value_data_type=response_type, converted_value_data_type=response_type, prompt_metadata=prompt_metadata or {}, diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 728f736f20..e9a25a3121 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -114,7 +114,6 @@ class MessagePiece(BaseModel): prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) prompt_target_identifier: ComponentIdentifierField | None = None - attack_identifier: ComponentIdentifierField | None = None scorer_identifier: ComponentIdentifierField | None = None scores: list[Score] = Field(default_factory=list) @@ -220,7 +219,7 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: Copy lineage metadata from ``source`` onto this piece. Lineage fields are the metadata that tie a piece back to its originating - conversation, attack, and target. Mutable containers (``labels``, + conversation and target. Mutable containers (``labels``, ``prompt_metadata``) are shallow-copied so that mutations on one piece do not affect others. @@ -229,7 +228,6 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: """ self.conversation_id = source.conversation_id self.labels = dict(source.labels) - self.attack_identifier = source.attack_identifier self.prompt_target_identifier = source.prompt_target_identifier self.prompt_metadata = dict(source.prompt_metadata) diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index e1422d68e7..ae7f763ab8 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -157,7 +157,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self._converter_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) converted_prompt = prompt diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index c089930778..9df030fd1a 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -88,7 +88,7 @@ async def send_prompt_async( labels (dict[str, str] | None, optional): Labels associated with the request. Defaults to None. Deprecated: This parameter will be removed in a release 0.16.0. attack_identifier (ComponentIdentifier | None, optional): Identifier for the attack. Defaults to - None. + None. Deprecated: this parameter is ignored and will be removed in release 0.17.0. Returns: Message: The response received from the target. @@ -103,6 +103,12 @@ async def send_prompt_async( new_item="send_prompt_async(...)", removed_in="0.16.0", ) + if attack_identifier is not None: + print_deprecation_message( + old_item="send_prompt_async(..., attack_identifier=...)", + new_item="send_prompt_async(...)", + removed_in="0.17.0", + ) # Validates that the MessagePieces in the Message are part of the same sequence request_converter_configurations = request_converter_configurations or [] response_converter_configurations = response_converter_configurations or [] @@ -118,8 +124,6 @@ async def send_prompt_async( if labels: piece.labels = labels # deprecated piece.prompt_target_identifier = target.get_identifier() - if attack_identifier: - piece.attack_identifier = attack_identifier # Apply request converters await self.convert_values_async(converter_configurations=request_converter_configurations, message=request) @@ -209,7 +213,7 @@ async def send_prompt_batch_to_target_async( labels (dict[str, str] | None, optional): A dictionary of labels to be included with the request. Defaults to None. attack_identifier (ComponentIdentifier | None, optional): The attack identifier. - Defaults to None. + Defaults to None. Deprecated: this parameter is ignored and will be removed in release 0.17.0. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. Returns: @@ -409,7 +413,8 @@ async def add_prepended_conversation_to_memory_async( should_convert (bool): Whether to convert the prepended conversation converter_configurations (list[PromptConverterConfiguration] | None): Configurations for converting the request - attack_identifier (ComponentIdentifier | None): Identifier for the attack + attack_identifier (ComponentIdentifier | None): Identifier for the attack. + Deprecated: this parameter is ignored and will be removed in release 0.17.0. prepended_conversation (list[Message] | None): The conversation to prepend Returns: @@ -418,6 +423,13 @@ async def add_prepended_conversation_to_memory_async( if not prepended_conversation: return None + if attack_identifier is not None: + print_deprecation_message( + old_item="add_prepended_conversation_to_memory_async(..., attack_identifier=...)", + new_item="add_prepended_conversation_to_memory_async(...)", + removed_in="0.17.0", + ) + # Create a deep copy of the prepended conversation to avoid modifying the original prepended_conversation = copy.deepcopy(prepended_conversation) @@ -426,8 +438,6 @@ async def add_prepended_conversation_to_memory_async( await self.convert_values_async(message=request, converter_configurations=converter_configurations) for piece in request.message_pieces: piece.conversation_id = conversation_id - if attack_identifier: - piece.attack_identifier = attack_identifier # if the piece is retrieved from somewhere else, it needs to be unique # and if not, this won't hurt anything diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 218f46d52e..91393732fa 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -295,6 +295,7 @@ def set_system_prompt( system_prompt (str): The system prompt text to set. conversation_id (str): The conversation id to attach the prompt to. attack_identifier (ComponentIdentifier | None): Optional attack identifier. + Deprecated: this parameter is ignored and will be removed in release 0.17.0. labels (dict[str, str] | None): Optional labels. Raises: @@ -308,6 +309,13 @@ def set_system_prompt( removed_in="0.16.0", ) + if attack_identifier is not None: + print_deprecation_message( + old_item="set_system_prompt(..., attack_identifier=...)", + new_item="set_system_prompt(...)", + removed_in="0.17.0", + ) + if not self.capabilities.supports_multi_turn or not self.capabilities.supports_editable_history: raise ValueError( f"Target {type(self).__name__} does not support setting a system prompt. " @@ -326,7 +334,6 @@ def set_system_prompt( original_value=system_prompt, converted_value=system_prompt, prompt_target_identifier=self.get_identifier(), - attack_identifier=attack_identifier, labels=labels or {}, ).to_message() ) diff --git a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py index abe1d89b06..c287c6e392 100644 --- a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py +++ b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py @@ -30,7 +30,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pyrit.models import ComponentIdentifier from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target.common.realtime_audio import CommittedEvent @@ -127,7 +126,6 @@ def __init__( response_converter_configurations: list[PromptConverterConfiguration] | None = None, prepended_conversation: list[Message] | None = None, server_vad: bool | ServerVadConfig = True, - attack_identifier: ComponentIdentifier | None = None, persist_prepended_conversation: bool = True, ) -> None: self._target = target @@ -137,7 +135,6 @@ def __init__( self._request_converter_configurations = request_converter_configurations or [] self._response_converter_configurations = response_converter_configurations or [] self._prepended_conversation = prepended_conversation or [] - self._attack_identifier = attack_identifier self._persist_prepended_conversation = persist_prepended_conversation # Normalize server_vad once at construction so config send and commit-time trim @@ -412,7 +409,6 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: converted_value_data_type="audio_path", conversation_id=self._conversation_id, prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) for cfg in self._request_converter_configurations: user_piece.converter_identifiers.extend(converter.get_identifier() for converter in cfg.converters) @@ -424,7 +420,6 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: original_value_data_type="text", conversation_id=self._conversation_id, prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) assistant_audio_piece = MessagePiece( role="assistant", @@ -432,7 +427,6 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: original_value_data_type="audio_path", conversation_id=self._conversation_id, prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) if result.interrupted: assistant_text_piece.prompt_metadata[STREAMING_INTERRUPTED_KEY] = True diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index e5ed421385..cd34c5748f 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -153,9 +153,8 @@ def open_streaming_session( server_vad: Server-side voice activity detection. ``True`` (default) enables VAD with default tuning. Pass a ``ServerVadConfig`` for custom tuning, or ``False`` to disable (sending streaming config will then raise). - attack_identifier: Stamped on every persisted user / assistant piece for - attribution. Pass the caller's identifier so live messages share the - provenance contract of prepended messages. + attack_identifier: Deprecated. This parameter is ignored and will be removed in + release 0.17.0. persist_prepended_conversation: When ``True`` (default), the session writes ``prepended_conversation`` to memory itself. Pass ``False`` when the caller already persisted the prepended conversation (e.g. via @@ -168,6 +167,12 @@ def open_streaming_session( (but not yielded). The session owns its websocket connection + dispatcher for the duration of ``run_async``. """ + if attack_identifier is not None: + print_deprecation_message( + old_item="open_streaming_session(..., attack_identifier=...)", + new_item="open_streaming_session(...)", + removed_in="0.17.0", + ) return _OpenAIRealtimeStreamingSession( target=self, audio_chunks=audio_chunks, @@ -177,7 +182,6 @@ def open_streaming_session( response_converter_configurations=response_converter_configurations, prepended_conversation=prepended_conversation, server_vad=server_vad, - attack_identifier=attack_identifier, persist_prepended_conversation=persist_prepended_conversation, ) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 6991996105..ad3d5f3d52 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -718,7 +718,6 @@ def _parse_response_output_section( conversation_id=message_piece.conversation_id, labels=message_piece.labels, # deprecated prompt_target_identifier=message_piece.prompt_target_identifier, - attack_identifier=message_piece.attack_identifier, original_value_data_type=piece_type, response_error=error or "none", ) @@ -826,5 +825,4 @@ def _make_tool_piece(self, output: dict[str, Any], call_id: str, *, reference_pi conversation_id=reference_piece.conversation_id, labels={"call_id": call_id}, # deprecated prompt_target_identifier=reference_piece.prompt_target_identifier, - attack_identifier=reference_piece.attack_identifier, ) diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index fa87ed9e4a..5f0c26a515 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -242,7 +242,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me conversation_id=request_piece.conversation_id, labels=request_piece.labels, # deprecated prompt_target_identifier=request_piece.prompt_target_identifier, - attack_identifier=request_piece.attack_identifier, original_value_data_type=piece_type, converted_value_data_type=piece_type, prompt_metadata=request_piece.prompt_metadata, diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d921b2e1cf..64ece08807 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -103,7 +103,6 @@ async def _score_async(self, message: Message, *, objective: str | None = None) conversation_id=original_piece.conversation_id, labels=original_piece.labels, # deprecated prompt_target_identifier=original_piece.prompt_target_identifier, - attack_identifier=original_piece.attack_identifier, original_value_data_type="text", converted_value_data_type="text", response_error="none", diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index e0501b0c12..8629b92bb0 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -5,7 +5,7 @@ from uuid import UUID from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, PromptDataType, Score, UnvalidatedScore +from pyrit.models import Message, PromptDataType, Score, UnvalidatedScore from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -146,7 +146,6 @@ async def _score_value_with_llm_async( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: @@ -164,7 +163,6 @@ async def _score_value_with_llm_async( description_output_key=description_output_key, metadata_output_key=metadata_output_key, category_output_key=category_output_key, - attack_identifier=attack_identifier, ) if score is None: raise ValueError("Score returned None") diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 85919178b0..f46b635110 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -94,7 +94,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._harm_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) # Modify the UnvalidatedScore parsing to check for 'score_value' diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 17105defb9..9631f944a9 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -148,7 +148,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, score_value_output_key=self._score_value_output_key, rationale_output_key=self._rationale_output_key, description_output_key=self._description_output_key, diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index f5f2e97bcf..750a86e7c6 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -453,7 +453,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, category=self._score_category, - attack_identifier=message_piece.attack_identifier, objective=objective, ) diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 92db37a06a..87e8e73b51 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -138,7 +138,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st prepended_text_message_piece=prepended_text, category=self._category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score( diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index f3cda9923b..a43e14b56a 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -350,7 +350,6 @@ def _create_text_piece_from_blocked(piece: MessagePiece) -> MessagePiece | None: prompt_metadata=piece.prompt_metadata, converter_identifiers=list(piece.converter_identifiers), # type: ignore[arg-type] prompt_target_identifier=piece.prompt_target_identifier, - attack_identifier=piece.attack_identifier, response_error="none", timestamp=piece.timestamp, ) @@ -676,7 +675,6 @@ async def _score_value_with_llm_async( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -710,8 +708,6 @@ async def _score_value_with_llm_async( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (ComponentIdentifier | None): The attack identifier. - Defaults to None. Returns: UnvalidatedScore: The score object containing the response from the target LLM. @@ -727,7 +723,6 @@ async def _score_value_with_llm_async( prompt_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, ) prompt_metadata: dict[str, str | int] = {"response_format": "json"} diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index cce672b642..9c526deff4 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -150,7 +150,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 71acd45a56..f706efbcbe 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -148,7 +148,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, score_value_output_key=self._score_value_output_key, rationale_output_key=self._rationale_output_key, description_output_key=self._description_output_key, diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index a2f5bc078e..0d05c67b76 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -92,7 +92,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index b5a5c2b80c..0ad8e598e4 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -194,7 +194,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index 0786d0db38..d315a2e4ed 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -229,7 +229,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st prepended_text_message_piece=prepended_text, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 47d8678b7d..58835daa14 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -251,7 +251,6 @@ def test_swaps_user_to_assistant(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -267,7 +266,6 @@ def test_swaps_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -286,7 +284,6 @@ def test_swaps_simulated_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -305,7 +302,6 @@ def test_skips_system_messages(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -322,7 +318,6 @@ def test_assigns_new_uuids(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -344,7 +339,6 @@ def test_preserves_message_content(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -356,7 +350,6 @@ def test_empty_prepended_conversation(self) -> None: result = get_adversarial_chat_messages( [], adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) @@ -371,7 +364,6 @@ def test_applies_labels(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels=labels, ) @@ -389,7 +381,6 @@ def test_labels_emit_deprecation_warning(self) -> None: get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels={"env": "prod"}, ) @@ -499,9 +490,8 @@ class TestConversationManagerInitialization: def test_init_with_required_parameters(self, attack_identifier: ComponentIdentifier) -> None: """Test initialization with only required parameters.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() - assert manager._attack_identifier == attack_identifier assert isinstance(manager._prompt_normalizer, PromptNormalizer) assert manager._memory is not None @@ -509,7 +499,7 @@ def test_init_with_custom_prompt_normalizer( self, attack_identifier: ComponentIdentifier, mock_prompt_normalizer: MagicMock ) -> None: """Test initialization with a custom prompt normalizer.""" - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) + manager = ConversationManager(prompt_normalizer=mock_prompt_normalizer) assert manager._prompt_normalizer == mock_prompt_normalizer @@ -525,7 +515,7 @@ class TestConversationRetrieval: def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: ComponentIdentifier) -> None: """Test get_conversation returns empty list for non-existent conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) result = manager.get_conversation(conversation_id) @@ -536,7 +526,7 @@ def test_get_conversation_returns_messages_in_order( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_conversation returns messages in order.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -553,7 +543,7 @@ def test_get_conversation_returns_messages_in_order( def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: ComponentIdentifier) -> None: """Test get_last_message returns None for empty conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) result = manager.get_last_message(conversation_id=conversation_id) @@ -564,7 +554,7 @@ def test_get_last_message_returns_last_piece( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns the most recent message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -582,7 +572,7 @@ def test_get_last_message_with_role_filter( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message with role filter returns correct message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -601,7 +591,7 @@ def test_get_last_message_with_role_filter_returns_none_when_no_match( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns None when no message matches role filter.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -629,7 +619,7 @@ def test_set_system_prompt_with_chat_target( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt calls target's set_system_prompt method.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) system_prompt = "You are a helpful assistant" labels = {"type": "system"} @@ -644,7 +634,6 @@ def test_set_system_prompt_with_chat_target( mock_chat_target.set_system_prompt.assert_called_once_with( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, labels=labels, ) @@ -652,7 +641,7 @@ def test_set_system_prompt_without_labels( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt works without labels.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) system_prompt = "You are a helpful assistant" @@ -670,7 +659,7 @@ def test_set_system_prompt_labels_emit_deprecation_warning( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test that passing labels emits deprecation warning.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() with patch( "pyrit.executor.attack.component.conversation_manager.print_deprecation_message" @@ -701,7 +690,7 @@ async def test_raises_error_for_empty_conversation_id( mock_attack_context: AttackContext, ) -> None: """Test that empty conversation_id raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() with pytest.raises(ValueError, match="conversation_id cannot be empty"): await manager.initialize_context_async( @@ -717,7 +706,7 @@ async def test_returns_default_state_for_no_prepended_conversation( mock_attack_context: AttackContext, ) -> None: """Test that no prepended conversation returns default state.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) state = await manager.initialize_context_async( @@ -736,7 +725,7 @@ async def test_merges_memory_labels( mock_chat_target: MagicMock, ) -> None: """Test that memory_labels are merged with context labels.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.memory_labels = {"context_key": "context_value"} @@ -759,7 +748,7 @@ async def test_adds_prepended_conversation_to_memory_for_chat_target( sample_conversation: list[Message], ) -> None: """Test that prepended conversation is added to memory for chat targets.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -781,7 +770,7 @@ async def test_converts_assistant_to_simulated_assistant( sample_assistant_piece: MessagePiece, ) -> None: """Test that assistant messages are converted to simulated_assistant.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = [Message(message_pieces=[sample_assistant_piece])] @@ -805,7 +794,7 @@ async def test_normalizes_for_non_chat_target_by_default( sample_conversation: list[Message], ) -> None: """Test that prepended conversation is normalized for non-chat targets by default.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -830,7 +819,7 @@ async def test_normalizes_for_non_chat_target_when_configured( sample_conversation: list[Message], ) -> None: """Test that non-chat target normalizes prepended conversation when configured.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -858,7 +847,7 @@ async def test_returns_turn_count_for_multi_turn_attacks( sample_conversation: list[Message], ) -> None: """Test that turn count is returned for multi-turn attacks.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -880,7 +869,7 @@ async def test_multipart_message_extracts_scores_from_all_pieces( sample_score: Score, ) -> None: """Test that multi-part assistant messages extract scores from all pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -957,7 +946,7 @@ async def test_prepended_conversation_ignores_true_scores( would incorrectly indicate the objective was already achieved. Only false scores are extracted to provide feedback rationale for continued attack attempts. """ - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -1056,7 +1045,7 @@ async def test_non_chat_target_behavior_normalize_is_default( sample_conversation: list[Message], ) -> None: """Test that non-chat targets normalize by default (no config), matching dataclass field default.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1081,7 +1070,7 @@ async def test_non_chat_target_behavior_raise_explicit( sample_conversation: list[Message], ) -> None: """Test that non_chat_target_behavior='raise' raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1108,7 +1097,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_creates_next_messag sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn creates next_message when none exists.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1135,7 +1124,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existin sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn prepends context to existing next_message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1164,7 +1153,7 @@ async def test_non_chat_target_behavior_normalize_returns_empty_state( sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn returns empty ConversationState (no turn tracking).""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1195,7 +1184,7 @@ async def test_apply_converters_to_roles_default_applies_to_all( """Test that converters are applied to all roles by default.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1221,7 +1210,7 @@ async def test_apply_converters_to_roles_user_only( """Test that converters are applied only to user role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1249,7 +1238,7 @@ async def test_apply_converters_to_roles_assistant_only( """Test that converters are applied only to assistant role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1277,7 +1266,7 @@ async def test_apply_converters_to_roles_empty_list_skips_all( """Test that empty roles list means no converters applied to any role.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1307,7 +1296,7 @@ async def test_message_normalizer_default_uses_conversation_context_normalizer( sample_conversation: list[Message], ) -> None: """Test that default normalizer produces Turn N format.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1340,7 +1329,7 @@ async def test_message_normalizer_custom_normalizer_is_used( mock_normalizer = MagicMock(spec=MessageStringNormalizer) mock_normalizer.normalize_string_async = AsyncMock(return_value="CUSTOM_FORMAT: test content") - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1423,7 +1412,7 @@ async def test_chat_target_ignores_non_chat_target_behavior( sample_conversation: list[Message], ) -> None: """Test that chat targets ignore non_chat_target_behavior setting.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1454,7 +1443,7 @@ async def test_config_with_max_turns_validation( mock_chat_target: MagicMock, ) -> None: """Test that config works correctly with max_turns validation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -1503,7 +1492,7 @@ async def test_adds_messages_to_memory( sample_conversation: list[Message], ) -> None: """Test that messages are added to memory.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1521,7 +1510,7 @@ async def test_assigns_conversation_id_to_all_pieces( sample_conversation: list[Message], ) -> None: """Test that conversation_id is assigned to all message pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) await manager.add_prepended_conversation_to_memory_async( @@ -1534,25 +1523,6 @@ async def test_assigns_conversation_id_to_all_pieces( for piece in msg.message_pieces: assert piece.conversation_id == conversation_id - async def test_assigns_attack_identifier_to_all_pieces( - self, - attack_identifier: ComponentIdentifier, - sample_conversation: list[Message], - ) -> None: - """Test that attack_identifier is assigned to all message pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) - conversation_id = str(uuid.uuid4()) - - await manager.add_prepended_conversation_to_memory_async( - prepended_conversation=sample_conversation, - conversation_id=conversation_id, - ) - - stored = manager.get_conversation(conversation_id) - for msg in stored: - for piece in msg.message_pieces: - assert piece.attack_identifier == attack_identifier - async def test_raises_error_when_exceeds_max_turns( self, attack_identifier: ComponentIdentifier, @@ -1560,7 +1530,7 @@ async def test_raises_error_when_exceeds_max_turns( sample_assistant_piece: MessagePiece, ) -> None: """Test that exceeding max_turns raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Create conversation with 2 assistant messages @@ -1583,7 +1553,7 @@ async def test_multipart_response_counts_as_one_turn( attack_identifier: ComponentIdentifier, ) -> None: """Test that a multi-part assistant response counts as only one turn.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) piece_conversation_id = str(uuid.uuid4()) @@ -1621,7 +1591,7 @@ async def test_returns_zero_for_empty_conversation( attack_identifier: ComponentIdentifier, ) -> None: """Test that empty conversation returns 0 turns.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1638,7 +1608,7 @@ async def test_applies_converters_when_provided( sample_user_piece: MessagePiece, ) -> None: """Test that converters are applied when provided.""" - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) + manager = ConversationManager(prompt_normalizer=mock_prompt_normalizer) conversation_id = str(uuid.uuid4()) conversation = [Message(message_pieces=[sample_user_piece])] converter_config = [PromptConverterConfiguration(converters=[])] @@ -1657,7 +1627,7 @@ async def test_handles_none_messages_gracefully( attack_identifier: ComponentIdentifier, ) -> None: """Test that None messages are handled gracefully.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1684,7 +1654,7 @@ async def test_preserves_piece_metadata( sample_user_piece: MessagePiece, ) -> None: """Test that piece metadata is preserved during processing.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add metadata to piece @@ -1712,7 +1682,7 @@ async def test_preserves_original_and_converted_values( sample_user_piece: MessagePiece, ) -> None: """Test that original and converted values are preserved.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) sample_user_piece.original_value = "Original message" @@ -1740,7 +1710,7 @@ async def test_handles_system_messages_in_prepended_conversation( sample_user_piece: MessagePiece, ) -> None: """Test that system messages are handled in prepended conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = [ diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index b10ff8e640..5164050b6f 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -568,7 +568,6 @@ async def test_get_objective_as_benign_question_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify message was created correctly (converted from seed group) @@ -616,7 +615,6 @@ async def test_get_benign_question_answer_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify template was rendered with benign request @@ -657,7 +655,6 @@ async def test_get_objective_as_question_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify template was rendered diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index 7e23fcbda6..bf9d61a627 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -418,7 +418,6 @@ async def test_send_prompt_to_target_with_all_configurations( assert call_args.kwargs["request_converter_configurations"] == request_converters assert call_args.kwargs["response_converter_configurations"] == response_converters assert call_args.kwargs["labels"] == {"test": "label"} - assert "attack_identifier" in call_args.kwargs async def test_send_prompt_handles_none_response(self, mock_target, mock_prompt_normalizer, basic_context): attack = PromptSendingAttack(objective_target=mock_target, prompt_normalizer=mock_prompt_normalizer) diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 8676199d51..f7c01c386f 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -295,7 +295,6 @@ async def test_send_skeleton_key_prompt_uses_correct_converters( assert call_args.kwargs["request_converter_configurations"] == request_converters assert call_args.kwargs["response_converter_configurations"] == response_converters assert call_args.kwargs["labels"] == {"test": "label"} - assert "attack_identifier" in call_args.kwargs @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/executor/attack/streaming/test_barge_in.py b/tests/unit/executor/attack/streaming/test_barge_in.py index fee1d181d9..7a4bc289fe 100644 --- a/tests/unit/executor/attack/streaming/test_barge_in.py +++ b/tests/unit/executor/attack/streaming/test_barge_in.py @@ -225,7 +225,6 @@ async def test_perform_async_opens_session_with_expected_kwargs(vad_target): assert kwargs["request_converter_configurations"] == attack._request_converters assert kwargs["response_converter_configurations"] == attack._response_converters assert kwargs["prepended_conversation"] == ctx.prepended_conversation - assert kwargs["attack_identifier"] == attack.get_identifier() assert kwargs["persist_prepended_conversation"] is False diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py index 4e633b8a9f..c088e0b0bc 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py @@ -90,7 +90,6 @@ async def test_converter_send_prompt_async_bad_json_exception_retries( original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py index 239a86d474..65c3805877 100644 --- a/tests/unit/memory/memory_interface/test_batching_scale.py +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -36,7 +36,6 @@ def _create_message_piece( converted_value_sha256=sha256, sequence=0, conversation_id=conversation_id or str(uuid.uuid4()), - attack_identifier=ComponentIdentifier.from_dict({"id": str(uuid.uuid4())}), ) diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 34252b7547..aafb83fb12 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -19,7 +19,7 @@ def test_export_conversation_by_attack_id_file_created( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): - attack1_id = sample_conversations[0].attack_identifier.hash + attack1_id = "attack-1" # Default path in export_conversations() file_name = f"{attack1_id}.json" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index f1261b9597..a703bb0ed4 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AttackResult, ComponentIdentifier, IdentifierFilter, IdentifierType, @@ -21,6 +22,7 @@ MessagePiece, Score, SeedPrompt, + build_atomic_attack_identifier, ) @@ -135,7 +137,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", @@ -143,14 +144,12 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="I'm fine, thank you!", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=conversation_id_3, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="user", @@ -158,7 +157,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id_2, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", @@ -166,7 +164,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="I'm fine, thank you!", conversation_id=conversation_id_2, sequence=1, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -179,28 +176,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): ) all_pieces = sqlite_instance.get_message_pieces() assert len(all_pieces) == 9 - # Attack IDs are preserved (not changed) when duplicating - assert all(p.attack_identifier is not None for p in all_pieces) - assert ( - len( - [ - p - for p in all_pieces - if p.attack_identifier is not None and p.attack_identifier.hash == attack1.get_identifier().hash - ] - ) - == 8 - ) - assert ( - len( - [ - p - for p in all_pieces - if p.attack_identifier is not None and p.attack_identifier.hash == attack2.get_identifier().hash - ] - ) - == 1 - ) assert len([p for p in all_pieces if p.conversation_id == conversation_id_1]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_2]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_3]) == 1 @@ -223,7 +198,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -233,7 +207,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac converted_value="I'm fine, thank you!", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), ] @@ -276,8 +249,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac for piece in new_pieces: assert piece.id not in (prompt_id_1, prompt_id_2) assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 - # Attack ID is preserved, so both original and duplicated pieces have the same attack ID - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 @@ -294,14 +265,12 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter original_value="original prompt text", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", @@ -309,7 +278,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="I'm fine, thank you!", sequence=2, conversation_id=conversation_id_1, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="user", @@ -317,7 +285,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="Hello, how are you?", conversation_id=conversation_id_2, sequence=2, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="assistant", @@ -325,7 +292,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="I'm fine, thank you!", conversation_id=conversation_id_2, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -359,7 +325,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -369,7 +334,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="I'm fine, thank you!", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -378,7 +342,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="That's good.", conversation_id=conversation_id, sequence=2, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -387,7 +350,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="Thanks.", conversation_id=conversation_id, sequence=3, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), ] @@ -430,8 +392,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M assert new_pieces[0].id != prompt_id_1 assert new_pieces[1].id != prompt_id_2 assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 - # Attack ID is preserved - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 @@ -445,28 +405,24 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: original_value="original prompt text", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", original_value="original prompt text", conversation_id=conversation_id_1, sequence=2, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -486,39 +442,6 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: assert piece.sequence < 2 -def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=get_mock_target()) - conversation_id = "11111" - pieces = [ - MessagePiece( - role="user", - original_value="original prompt text", - converted_value="Hello, how are you?", - conversation_id=conversation_id, - sequence=0, - attack_identifier=attack1.get_identifier(), - ), - ] - sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - assert len(sqlite_instance.get_message_pieces()) == 1 - - # Duplicating preserves the attack ID - new_conversation_id = sqlite_instance.duplicate_conversation( - conversation_id=conversation_id, - ) - - # Verify duplication succeeded - all_pieces = sqlite_instance.get_message_pieces() - assert len(all_pieces) == 2 - assert new_conversation_id != conversation_id - - # Both pieces should have the same attack ID - assert all(p.attack_identifier is not None for p in all_pieces) - attack_ids = {p.attack_identifier.hash for p in all_pieces if p.attack_identifier is not None} - assert len(attack_ids) == 1 - assert attack1.get_identifier().hash in attack_ids - - def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface): """Test that duplicated conversation has new piece IDs.""" attack1 = PromptSendingAttack(objective_target=get_mock_target()) @@ -529,7 +452,6 @@ def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface converted_value="Hello", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[original_piece]) @@ -560,7 +482,6 @@ def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: Me original_value="traceable prompt", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[original_piece]) original_prompt_id = original_piece.original_prompt_id @@ -586,21 +507,18 @@ def test_duplicate_conversation_with_multiple_pieces(sqlite_instance: MemoryInte original_value="user message 1", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="assistant response 1", conversation_id=conversation_id, sequence=2, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", original_value="user message 2", conversation_id=conversation_id, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -932,31 +850,29 @@ def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) - entries = [ - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 1", - attack_identifier=attack1.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="assistant", - original_value="Hello 2", - attack_identifier=attack2.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 3", - attack_identifier=attack1.get_identifier(), - ) - ), + pieces = [ + MessagePiece(role="user", original_value="Hello 1", conversation_id="c1", sequence=0), + MessagePiece(role="assistant", original_value="Hello 2", conversation_id="c2", sequence=0), + MessagePiece(role="user", original_value="Hello 3", conversation_id="c1", sequence=1), ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - sqlite_instance._insert_entries(entries=entries) + # attack_identifier is no longer stamped on pieces; the deprecated attack_id filter + # resolves to an attack's main conversation via persisted AttackResults. + sqlite_instance.add_attack_results_to_memory( + attack_results=[ + AttackResult( + conversation_id="c1", + objective="objective 1", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack1.get_identifier()), + ), + AttackResult( + conversation_id="c2", + objective="objective 2", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack2.get_identifier()), + ), + ] + ) attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier().hash) @@ -1115,7 +1031,6 @@ def test_get_message_pieces_with_non_matching_memory_labels(sqlite_instance: Mem role="user", original_value="Hello 3", converted_value="Hello 1", - attack_identifier=attack.get_identifier(), ) ), ] @@ -1371,53 +1286,21 @@ def test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) - attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) - entries = [ - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 1", - attack_identifier=attack1.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="assistant", - original_value="Hello 2", - attack_identifier=attack2.get_identifier(), - ) - ), - ] - - sqlite_instance._insert_entries(entries=entries) - - # Filter by exact attack hash - results = sqlite_instance.get_message_pieces( - identifier_filters=[ - IdentifierFilter( - identifier_type=IdentifierType.ATTACK, - property_path="$.hash", - value=attack1.get_identifier().hash, - partial_match=False, - ) - ], - ) - assert len(results) == 1 - assert results[0].original_value == "Hello 1" - - # No match - results = sqlite_instance.get_message_pieces( - identifier_filters=[ - IdentifierFilter( - identifier_type=IdentifierType.ATTACK, - property_path="$.hash", - value="nonexistent_hash", - partial_match=False, - ) - ], - ) - assert len(results) == 0 + # IdentifierType.ATTACK is no longer stamped on message pieces, so the piece-level + # identifier filter rejects it. Attack filtering now goes through get_attack_results + # or the deprecated attack_id parameter. + with pytest.raises(ValueError, match="does not support identifier type"): + sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value=attack1.get_identifier().hash, + partial_match=False, + ) + ], + ) def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): @@ -1432,24 +1315,20 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, ) - entries = [ - PromptMemoryEntry( - entry=MessagePiece( + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[ + MessagePiece( role="user", original_value="Hello OpenAI", prompt_target_identifier=target_id_1, - ) - ), - PromptMemoryEntry( - entry=MessagePiece( + ), + MessagePiece( role="user", original_value="Hello Azure", prompt_target_identifier=target_id_2, - ) - ), - ] - - sqlite_instance._insert_entries(entries=entries) + ), + ] + ) # Filter by target hash results = sqlite_instance.get_message_pieces( diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 7f786c260c..ead87d6666 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -8,17 +8,17 @@ from uuid import uuid4 import pytest -from unit.mocks import get_mock_target -from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AttackResult, ComponentIdentifier, IdentifierFilter, IdentifierType, MessagePiece, Score, SeedPrompt, + build_atomic_attack_identifier, ) @@ -41,6 +41,19 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_message_pieces_to_memory(message_pieces=sample_conversations) + # attack_identifier is no longer stamped on pieces; the deprecated attack_id filter + # resolves to an attack's main conversation via persisted AttackResults. + attack_strategy_id = ComponentIdentifier(class_name="TestAttack", class_module="test.module") + sqlite_instance.add_attack_results_to_memory( + attack_results=[ + AttackResult( + conversation_id=sample_conversations[0].conversation_id, + objective="test objective", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_strategy_id), + ) + ] + ) + score = Score( score_value=str(0.8), score_value_description="High score", @@ -55,8 +68,7 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_scores_to_memory(scores=[score]) # Fetch the score we just added - assert sample_conversations[0].attack_identifier is not None - db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier.hash) + db_score = sqlite_instance.get_prompt_scores(attack_id=attack_strategy_id.hash) assert len(db_score) == 1 assert db_score[0].score_value == score.score_value @@ -76,9 +88,8 @@ def test_get_scores_by_attack_id_and_label( assert len(db_score) == 1 assert db_score[0].score_value == score.score_value - assert sample_conversations[0].attack_identifier is not None db_score = sqlite_instance.get_prompt_scores( - attack_id=sample_conversations[0].attack_identifier.hash, + attack_id=attack_strategy_id.hash, labels={"x": "y"}, ) assert len(db_score) == 0 @@ -161,7 +172,6 @@ def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: Memor def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4() - attack = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = str(uuid4()) pieces = [ MessagePiece( @@ -171,12 +181,11 @@ def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack.get_identifier(), ) ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) sqlite_instance.duplicate_conversation(conversation_id=conversation_id) - # Get the duplicated piece (it will have a different conversation_id but same attack_id) + # Get the duplicated piece (it will have a different conversation_id) all_pieces = sqlite_instance.get_message_pieces() dupe_piece = [p for p in all_pieces if p.id != original_id][0] dupe_id = dupe_piece.id diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index b8ee5bb6dc..385fd26809 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -4,7 +4,6 @@ import os import uuid from collections.abc import Generator, MutableSequence, Sequence -from datetime import timezone from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch @@ -197,49 +196,40 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): converter_identifiers = [Base64Converter().get_identifier()] target = TextTarget() - # Start a session - with memory_interface.get_session() as session: # type: ignore[arg-type] - # Create a ConversationData entry with all attributes filled - entry = PromptMemoryEntry( - entry=MessagePiece( - conversation_id=specific_conversation_id, - role="user", - sequence=1, - original_value="Test content", - converted_value="Test content", - labels={"normalizer_id": "id1"}, - converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), - ) - ) + piece = MessagePiece( + conversation_id=specific_conversation_id, + role="user", + sequence=1, + original_value="Test content", + converted_value="Test content", + labels={"normalizer_id": "id1"}, + converter_identifiers=converter_identifiers, + prompt_target_identifier=target.get_identifier(), + ) + + memory_interface.add_message_pieces_to_memory(message_pieces=[piece]) + + # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id + retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) + + # Verify that the retrieved entry matches the inserted entry + assert len(retrieved_entries) == 1 + retrieved_entry = retrieved_entries[0].message_pieces[0] + assert retrieved_entry.conversation_id == specific_conversation_id + assert retrieved_entry.api_role == "user" + assert retrieved_entry.original_value == "Test content" + # For timestamp, you might want to check if it's close to the current time instead of an exact match + assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 10 # Assuming the test runs quickly + + converter_identifiers = retrieved_entry.converter_identifiers + assert len(converter_identifiers) == 1 + assert converter_identifiers[0].class_name == "Base64Converter" + + prompt_target = retrieved_entry.prompt_target_identifier + assert prompt_target.class_name == "TextTarget" - # Insert the ConversationData entry - session.add(entry) - session.commit() - - # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) - - # Verify that the retrieved entry matches the inserted entry - assert len(retrieved_entries) == 1 - retrieved_entry = retrieved_entries[0].message_pieces[0] - assert retrieved_entry.conversation_id == specific_conversation_id - assert retrieved_entry.api_role == "user" - assert retrieved_entry.original_value == "Test content" - # For timestamp, you might want to check if it's close to the current time instead of an exact match - assert ( - abs((retrieved_entry.timestamp - entry.timestamp.replace(tzinfo=timezone.utc)).total_seconds()) < 10 - ) # Assuming the test runs quickly - - converter_identifiers = retrieved_entry.converter_identifiers - assert len(converter_identifiers) == 1 - assert converter_identifiers[0].class_name == "Base64Converter" - - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" - - labels = retrieved_entry.labels - assert labels["normalizer_id"] == "id1" + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" def test_get_memories_with_attack_id(memory_interface: AzureSQLMemory): diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index 3f3b7ac990..ee3bab305f 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -53,7 +53,6 @@ def _make_message_piece(**overrides) -> MessagePiece: "prompt_metadata": {"meta": "data"}, "converter_identifiers": [ComponentIdentifier(class_name="NoOp", class_module="pyrit.converters")], "prompt_target_identifier": ComponentIdentifier(class_name="MockTarget", class_module="tests.mocks"), - "attack_identifier": ComponentIdentifier(class_name="MockAttack", class_module="tests.mocks"), "original_value_data_type": "text", "converted_value_data_type": "text", "response_error": "none", @@ -224,16 +223,6 @@ def test_init_stores_converter_identifiers_as_dicts(self): assert isinstance(entry.converter_identifiers, list) assert isinstance(entry.converter_identifiers[0], dict) - def test_init_with_no_attack_identifier(self): - piece = _make_message_piece(attack_identifier=None) - entry = PromptMemoryEntry(entry=piece) - assert entry.attack_identifier == {} - - def test_init_with_no_target_identifier(self): - piece = _make_message_piece(prompt_target_identifier=None) - entry = PromptMemoryEntry(entry=piece) - assert entry.prompt_target_identifier == {} - def test_roundtrip_get_message_piece(self): piece = _make_message_piece() entry = PromptMemoryEntry(entry=piece) @@ -245,13 +234,6 @@ def test_roundtrip_get_message_piece(self): assert recovered.conversation_id == piece.conversation_id assert isinstance(recovered.converter_identifiers[0], ComponentIdentifier) - def test_str_with_target_identifier(self): - piece = _make_message_piece() - entry = PromptMemoryEntry(entry=piece) - s = str(entry) - assert "MockTarget" in s - assert "user" in s - def test_str_without_target_identifier(self): piece = _make_message_piece(prompt_target_identifier=None) entry = PromptMemoryEntry(entry=piece) diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 3eddf11133..bc7265c360 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -7,7 +7,6 @@ import tempfile import uuid from collections.abc import Sequence -from datetime import timezone from unittest.mock import MagicMock import pytest @@ -58,7 +57,6 @@ def test_conversation_data_schema(sqlite_instance): "labels", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", "original_value_data_type", "original_value", "original_value_sha256", @@ -97,8 +95,6 @@ def test_conversation_data_column_types(sqlite_instance): "labels": (String, JSON), "prompt_metadata": (String, JSON), "converter_identifiers": (String, JSON), - "prompt_target_identifier": (String, JSON), - "attack_identifier": (String, JSON), "response_error": String, "original_value_data_type": String, "original_value": String, @@ -522,47 +518,40 @@ def test_get_memories_with_json_properties(sqlite_instance): converter_identifiers = [Base64Converter().get_identifier()] target = TextTarget() - # Start a session - with sqlite_instance.get_session() as session: - # Create a ConversationData entry with all attributes filled - piece = MessagePiece( - conversation_id=specific_conversation_id, - role="user", - sequence=1, - original_value="Test content", - converted_value="Test content", - labels={"normalizer_id": "id1"}, - converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), - ) - entry = PromptMemoryEntry(entry=piece) + piece = MessagePiece( + conversation_id=specific_conversation_id, + role="user", + sequence=1, + original_value="Test content", + converted_value="Test content", + labels={"normalizer_id": "id1"}, + converter_identifiers=converter_identifiers, + prompt_target_identifier=target.get_identifier(), + ) - # Insert the ConversationData entry - session.add(entry) - session.commit() + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) - # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) + # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id + retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) - # Verify that the retrieved entry matches the inserted entry - assert len(retrieved_entries) == 1 - retrieved_entry = retrieved_entries[0].message_pieces[0] - assert retrieved_entry.conversation_id == specific_conversation_id - assert retrieved_entry.api_role == "user" - assert retrieved_entry.original_value == "Test content" - # For timestamp, you might want to check if it's close to the current time instead of an exact match - assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 0.1 - assert abs((retrieved_entry.timestamp - entry.timestamp.replace(tzinfo=timezone.utc)).total_seconds()) < 0.1 + # Verify that the retrieved entry matches the inserted entry + assert len(retrieved_entries) == 1 + retrieved_entry = retrieved_entries[0].message_pieces[0] + assert retrieved_entry.conversation_id == specific_conversation_id + assert retrieved_entry.api_role == "user" + assert retrieved_entry.original_value == "Test content" + # For timestamp, you might want to check if it's close to the current time instead of an exact match + assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 0.1 - converter_identifiers = retrieved_entry.converter_identifiers - assert len(converter_identifiers) == 1 - assert converter_identifiers[0].class_name == "Base64Converter" + converter_identifiers = retrieved_entry.converter_identifiers + assert len(converter_identifiers) == 1 + assert converter_identifiers[0].class_name == "Base64Converter" - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" + prompt_target = retrieved_entry.prompt_target_identifier + assert prompt_target.class_name == "TextTarget" - labels = retrieved_entry.labels - assert labels["normalizer_id"] == "id1" + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" def test_update_entries(sqlite_instance): diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index dbd1a8a4d4..2616bd1fd5 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -150,7 +150,6 @@ def set_system_prompt( original_value=system_prompt, converted_value=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, labels=labels or {}, ).to_message() ) @@ -165,7 +164,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me role="assistant", original_value="default", conversation_id=message.message_pieces[0].conversation_id, - attack_identifier=message.message_pieces[0].attack_identifier, labels=message.message_pieces[0].labels, ).to_message() ] @@ -259,7 +257,6 @@ def get_test_message_piece() -> MessagePiece: def get_sample_conversations() -> MutableSequence[Message]: with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): conversation_1 = str(uuid.uuid4()) - attack_id = get_mock_attack_identifier() return [ MessagePiece( @@ -268,7 +265,6 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="Hello, how are you?", conversation_id=conversation_1, sequence=0, - attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", @@ -276,14 +272,12 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="I'm fine, thank you!", conversation_id=conversation_1, sequence=1, - attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=str(uuid.uuid4()), - attack_identifier=attack_id, ).to_message(), ] diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index ea50d4de7e..b8c64404d5 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -387,7 +387,6 @@ def test_to_dict_from_dict_roundtrip(): sequence=1, timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), prompt_target_identifier=target_id, - attack_identifier=attack_id, ) last_score = Score( score_value="true", diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index e8d4457d81..d4872ccec9 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -10,9 +10,8 @@ from unittest.mock import patch import pytest -from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations +from unit.mocks import MockPromptTarget, get_sample_conversations -from pyrit.executor.attack import PromptSendingAttack from pyrit.models import ( ComponentIdentifier, Message, @@ -83,21 +82,6 @@ def test_prompt_targets_serialize(patch_central_database): assert entry.prompt_target_identifier.class_module == "unit.mocks" -def test_executors_serialize(): - attack = PromptSendingAttack(objective_target=get_mock_target()) - - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - attack_identifier=attack.get_identifier(), - ) - - assert entry.attack_identifier.hash is not None - assert entry.attack_identifier.class_name == "PromptSendingAttack" - assert entry.attack_identifier.class_module == "pyrit.executor.attack.single_turn.prompt_sending" - - async def test_hashes_generated(): entry = MessagePiece( role="user", @@ -693,10 +677,6 @@ def test_message_piece_to_dict(): class_name="MockPromptTarget", class_module="unit.mocks", ), - attack_identifier=ComponentIdentifier( - class_name="PromptSendingAttack", - class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", - ), scorer_identifier=ComponentIdentifier( class_name="TestScorer", class_module="pyrit.score.test_scorer", @@ -740,7 +720,6 @@ def test_message_piece_to_dict(): "prompt_metadata", "converter_identifiers", "prompt_target_identifier", - "attack_identifier", "scorer_identifier", "original_value_data_type", "original_value", @@ -768,7 +747,6 @@ def test_message_piece_to_dict(): assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() - assert result["attack_identifier"] == entry.attack_identifier.to_dict() assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value @@ -1094,7 +1072,6 @@ def test_to_dict_from_dict_roundtrip(): prompt_metadata={"doc_type": "text"}, converter_identifiers=[converter_id], prompt_target_identifier=target_id, - attack_identifier=attack_id, original_value_data_type="text", converted_value_data_type="text", response_error="none", @@ -1141,7 +1118,6 @@ def _make_piece(self, **overrides) -> MessagePiece: def test_copies_lineage_fields_from_source_to_target(self) -> None: source = self._make_piece( conversation_id="conv-A", - attack_identifier={"__type__": "Attack", "__module__": "x", "id": "atk-1"}, prompt_target_identifier={"__type__": "Target", "__module__": "x", "id": "tgt-1"}, ) source.prompt_metadata = {"k": "v"} @@ -1151,7 +1127,6 @@ def test_copies_lineage_fields_from_source_to_target(self) -> None: target.copy_lineage_from(source=source) assert target.conversation_id == "conv-A" - assert target.attack_identifier == source.attack_identifier assert target.prompt_target_identifier == source.prompt_target_identifier assert target.prompt_metadata == {"k": "v"} @@ -1224,7 +1199,6 @@ def test_to_dict_golden_shape(self) -> None: "prompt_metadata", "converter_identifiers", "prompt_target_identifier", - "attack_identifier", "scorer_identifier", "scores", ] @@ -1239,7 +1213,6 @@ def test_to_dict_golden_shape(self) -> None: assert d["prompt_metadata"] == {} assert d["converter_identifiers"] == [] assert d["prompt_target_identifier"] is None - assert d["attack_identifier"] is None assert d["scorer_identifier"] is None assert d["original_value_data_type"] == "text" assert d["original_value"] == "hello" diff --git a/tests/unit/prompt_converter/test_persuasion_converter.py b/tests/unit/prompt_converter/test_persuasion_converter.py index 197e5564ef..02cd38301c 100644 --- a/tests/unit/prompt_converter/test_persuasion_converter.py +++ b/tests/unit/prompt_converter/test_persuasion_converter.py @@ -73,7 +73,6 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/prompt_converter/test_variation_converter.py b/tests/unit/prompt_converter/test_variation_converter.py index 542fccf0c1..1357894a2a 100644 --- a/tests/unit/prompt_converter/test_variation_converter.py +++ b/tests/unit/prompt_converter/test_variation_converter.py @@ -45,7 +45,6 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index decc056441..899a75b1d9 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -614,7 +614,6 @@ def test_memory_property_raises_when_memory_none(): async def test_add_prepended_conversation_to_memory(mock_memory_instance): normalizer = PromptNormalizer() conv_id = "test-conv-id" - attack_id = get_mock_attack_identifier() piece = MessagePiece(role="user", original_value="prepended text", conversation_id="old-id") message = Message(message_pieces=[piece]) @@ -622,14 +621,12 @@ async def test_add_prepended_conversation_to_memory(mock_memory_instance): result = await normalizer.add_prepended_conversation_to_memory_async( conversation_id=conv_id, should_convert=False, - attack_identifier=attack_id, prepended_conversation=[message], ) assert result is not None assert len(result) == 1 assert result[0].message_pieces[0].conversation_id == conv_id - assert result[0].message_pieces[0].attack_identifier == attack_id mock_memory_instance.add_message_to_memory.assert_called_once() diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 2bd58d18a0..5a97782ce6 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -37,7 +37,6 @@ def _make_message_piece(*, role: str, content: str, conversation_id: str = "conv original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 6483c931ea..6da03a1311 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -284,7 +284,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -295,7 +294,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -394,7 +392,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -405,7 +402,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -459,7 +455,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -470,7 +465,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py index c765e7e4d6..97c71618c1 100644 --- a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py +++ b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py @@ -12,7 +12,7 @@ import pytest -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_target.common.realtime_audio import ( STREAMING_INTERRUPTED_KEY, CommittedEvent, @@ -642,73 +642,6 @@ async def _fire() -> None: assert len(calls[2].args[0]) == 14400 -# --------------------------------------------------------------------------- -# 9. attack_identifier is stamped on persisted user + assistant pieces -# --------------------------------------------------------------------------- - - -async def test_attack_identifier_stamped_on_persisted_pieces_when_set(): - """When ``attack_identifier`` is provided, every persisted piece carries it.""" - target = _build_target() - normalizer = _build_normalizer() - - persisted_messages: list[Message] = [] - - async def _capture(*, message: Message) -> None: - persisted_messages.append(message) - - normalizer.hash_and_persist_message_async = AsyncMock(side_effect=_capture) - - attack_id = ComponentIdentifier(class_name="BargeInAttack", class_module="test") - - finish = asyncio.Event() - session = _OpenAIRealtimeStreamingSession( - target=target, - audio_chunks=_paced_chunks([b"\x01" * 96], finish), - prompt_normalizer=normalizer, - attack_identifier=attack_id, - ) - _mock_session_wire(session) - - with _patched_dispatcher(): - await _run_session_with_events(session, finish=finish, events=[CommittedEvent(item_id="i")]) - - # Expect one user message + one assistant message (two pieces) — three pieces total. - all_pieces = [piece for msg in persisted_messages for piece in msg.message_pieces] - assert len(all_pieces) == 3 - for piece in all_pieces: - assert piece.attack_identifier == attack_id - - -async def test_attack_identifier_absent_when_not_provided(): - """Without ``attack_identifier``, persisted pieces have None attribution (back-compat).""" - target = _build_target() - normalizer = _build_normalizer() - - persisted_messages: list[Message] = [] - - async def _capture(*, message: Message) -> None: - persisted_messages.append(message) - - normalizer.hash_and_persist_message_async = AsyncMock(side_effect=_capture) - - finish = asyncio.Event() - session = _OpenAIRealtimeStreamingSession( - target=target, - audio_chunks=_paced_chunks([b"\x01" * 96], finish), - prompt_normalizer=normalizer, - ) - _mock_session_wire(session) - - with _patched_dispatcher(): - await _run_session_with_events(session, finish=finish, events=[CommittedEvent(item_id="i")]) - - all_pieces = [piece for msg in persisted_messages for piece in msg.message_pieces] - assert len(all_pieces) == 3 - for piece in all_pieces: - assert piece.attack_identifier is None - - # --------------------------------------------------------------------------- # 10. persist_prepended_conversation=False skips the prepended-memory write # --------------------------------------------------------------------------- @@ -770,7 +703,6 @@ async def _empty(): req_cfgs = [MagicMock(name="req_cfg")] resp_cfgs = [MagicMock(name="resp_cfg")] vad = ServerVadConfig(prefix_padding_ms=42) - attack_id = {"__type__": "BargeInAttack", "id": "x"} captured: dict[str, Any] = {} @@ -790,7 +722,6 @@ def _fake_session_ctor(**kwargs): response_converter_configurations=resp_cfgs, prepended_conversation=prepended, server_vad=vad, - attack_identifier=attack_id, persist_prepended_conversation=False, ) @@ -802,7 +733,6 @@ def _fake_session_ctor(**kwargs): assert captured["response_converter_configurations"] is resp_cfgs assert captured["prepended_conversation"] is prepended assert captured["server_vad"] is vad - assert captured["attack_identifier"] is attack_id assert captured["persist_prepended_conversation"] is False diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index 3e90cbc00c..da758724b4 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -307,7 +307,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -318,7 +317,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -400,7 +398,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -411,7 +408,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -446,7 +442,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di original_value_data_type="text", converted_value_data_type="text", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -457,7 +452,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di original_value_data_type="image_path", converted_value_data_type="image_path", prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 4ad66bdbde..8b89fcb175 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -60,7 +60,6 @@ def test_set_system_prompt(azure_openai_target: OpenAIChatTarget, mock_attack_st azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -76,7 +75,6 @@ async def test_set_system_prompt_adds_memory( azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -110,7 +108,6 @@ async def test_send_prompt_with_system_calls_chat_complete( azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -164,7 +161,6 @@ async def test_send_prompt_async_with_delay( _LINEAGE_CONVERSATION_ID = "original-conv-id-12345" _LINEAGE_LABELS = {"op_name": "test_op", "user_id": "user42"} -_LINEAGE_ATTACK_IDENTIFIER = ComponentIdentifier(class_name="TestAttack", class_module="tests.attacks") _LINEAGE_PROMPT_TARGET_IDENTIFIER = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit") _LINEAGE_PROMPT_METADATA = {"scenario": "test_scenario", "turn": 3} @@ -179,7 +175,6 @@ def _make_lineage_piece(*, role: str, content: str) -> MessagePiece: converted_value_data_type="text", labels=dict(_LINEAGE_LABELS), prompt_target_identifier=_LINEAGE_PROMPT_TARGET_IDENTIFIER, - attack_identifier=_LINEAGE_ATTACK_IDENTIFIER, prompt_metadata=dict(_LINEAGE_PROMPT_METADATA), ) @@ -242,7 +237,6 @@ async def test_history_squash_preserves_metadata_on_normalized_message(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -291,7 +285,6 @@ async def test_response_preserves_metadata_after_history_squash(): assert response_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert response_piece.labels == _LINEAGE_LABELS - assert response_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER assert response_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert response_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -338,7 +331,6 @@ async def test_system_squash_preserves_metadata(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -389,7 +381,6 @@ async def test_history_squash_propagates_lineage_to_all_pieces(): for piece in normalized[0].message_pieces: assert piece.conversation_id == _LINEAGE_CONVERSATION_ID assert piece.labels == _LINEAGE_LABELS - assert piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER assert piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -453,7 +444,6 @@ async def test_conversation_id_stamped_on_all_but_full_lineage_only_on_last(): # Last message should carry full lineage. last_piece = normalized[-1].message_pieces[0] assert last_piece.labels == _LINEAGE_LABELS - assert last_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER assert last_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert last_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index 0e957482a2..c326f8bf15 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -252,7 +252,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data conversation_id=conversation_id, labels={"test": "label"}, prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) @@ -289,7 +288,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data assert called_piece.conversation_id == message_piece.conversation_id assert called_piece.labels == message_piece.labels assert called_piece.prompt_target_identifier == message_piece.prompt_target_identifier - assert called_piece.attack_identifier == message_piece.attack_identifier async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(patch_central_database): diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 01378ebde3..6491ccefeb 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -205,71 +205,6 @@ async def test_scorer_score_value_with_llm_exception_display_prompt_id(): ) -async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_json): - scorer = MockScorer() - - message = Message( - message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] - ) - chat_target = MagicMock(PromptTarget) - chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") - chat_target.send_prompt_async = AsyncMock(return_value=[message]) - chat_target.set_system_prompt = MagicMock() - - expected_system_prompt = "system_prompt" - expected_attack_identifier = ComponentIdentifier(class_name="TestAttack", class_module="test.module") - expected_scored_prompt_id = "123" - - await scorer._score_value_with_llm_async( - prompt_target=chat_target, - system_prompt=expected_system_prompt, - message_value="message_value", - message_data_type="text", - scored_prompt_id=expected_scored_prompt_id, - category="category", - objective="task", - attack_identifier=expected_attack_identifier, - ) - - chat_target.set_system_prompt.assert_called_once() - - _, set_sys_prompt_args = chat_target.set_system_prompt.call_args - assert set_sys_prompt_args["system_prompt"] == expected_system_prompt - assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert set_sys_prompt_args["attack_identifier"] is expected_attack_identifier - - -async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empty_attack_identifier(good_json): - scorer = MockScorer() - - message = Message( - message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] - ) - chat_target = MagicMock(PromptTarget) - chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") - chat_target.send_prompt_async = AsyncMock(return_value=[message]) - chat_target.set_system_prompt = MagicMock() - - expected_system_prompt = "system_prompt" - - await scorer._score_value_with_llm_async( - prompt_target=chat_target, - system_prompt=expected_system_prompt, - message_value="message_value", - message_data_type="text", - scored_prompt_id="123", - category="category", - objective="task", - ) - - chat_target.set_system_prompt.assert_called_once() - - _, set_sys_prompt_args = chat_target.set_system_prompt.call_args - assert set_sys_prompt_args["system_prompt"] == expected_system_prompt - assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert not set_sys_prompt_args["attack_identifier"] - - async def test_scorer_send_chat_target_async_good_response(good_json): chat_target = MagicMock(PromptTarget) chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") From a7e5fe1d60636ee8f29fb522616db821c56d3aed Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sat, 6 Jun 2026 22:11:30 -0700 Subject: [PATCH 02/12] Fix tests and docs for Conversation model cutover Migrate remaining test sites off the removed MessagePiece.prompt_target_identifier field, thread target_identifier through duplication call sites, and update the memory schema/data-type docs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/code/memory/10_schema_diagram.md | 7 ++- doc/code/memory/3_memory_data_types.md | 4 +- pyrit/backend/services/attack_service.py | 26 ++++++++-- .../attack/component/conversation_manager.py | 13 +++-- .../multi_turn/multi_turn_attack_strategy.py | 4 +- .../executor/attack/multi_turn/red_teaming.py | 5 +- .../attack/multi_turn/tree_of_attacks.py | 5 +- .../promptgen/fuzzer/fuzzer_converter_base.py | 1 - .../fuzzer/fuzzer_crossover_converter.py | 1 - .../fuzzer/fuzzer_expand_converter.py | 1 - pyrit/memory/azure_sql_memory.py | 15 +++--- pyrit/memory/memory_interface.py | 49 ++++++++++++++----- pyrit/memory/memory_models.py | 18 ------- pyrit/memory/sqlite_memory.py | 19 +++---- pyrit/models/messages/conversations.py | 1 - pyrit/models/messages/message_piece.py | 4 +- pyrit/models/seeds/seed_group.py | 1 - .../llm_generic_text_converter.py | 1 - pyrit/prompt_normalizer/prompt_normalizer.py | 29 +++++++---- .../common/discover_target_capabilities.py | 12 +++-- pyrit/prompt_target/common/prompt_target.py | 4 +- .../_openai_realtime_streaming_session.py | 12 +++-- .../openai/openai_response_target.py | 2 - .../playwright_copilot_target.py | 1 - pyrit/prompt_target/text_target.py | 5 +- pyrit/score/conversation_scorer.py | 1 - pyrit/score/scorer.py | 3 -- pyrit/score/true_false/gandalf_scorer.py | 1 - .../score/true_false/prompt_shield_scorer.py | 1 - tests/unit/backend/test_attack_service.py | 11 ++++- .../component/test_conversation_manager.py | 9 ---- .../test_supports_multi_turn_attacks.py | 20 ++++---- .../promptgen/fuzzer/test_fuzzer_converter.py | 3 +- .../test_interface_prompts.py | 12 +++-- tests/unit/memory/test_azure_sql_memory.py | 11 +++-- tests/unit/memory/test_memory_models.py | 5 +- tests/unit/memory/test_sqlite_memory.py | 11 +++-- tests/unit/models/test_attack_result.py | 6 --- tests/unit/models/test_message_piece.py | 26 +--------- .../test_persuasion_converter.py | 6 +-- .../test_translation_converter.py | 5 +- .../test_variation_converter.py | 6 +-- .../test_prompt_normalizer.py | 1 + .../prompt_target/target/test_http_target.py | 6 --- .../test_normalize_async_integration.py | 3 +- .../target/test_openai_chat_target.py | 8 +-- .../test_openai_realtime_streaming_session.py | 4 +- .../target/test_openai_response_target.py | 8 +-- .../target/test_prompt_target.py | 7 --- .../prompt_target/test_round_robin_target.py | 2 - .../score/test_conversation_history_scorer.py | 2 - 51 files changed, 201 insertions(+), 217 deletions(-) diff --git a/doc/code/memory/10_schema_diagram.md b/doc/code/memory/10_schema_diagram.md index 9fbcefc2a2..40837b71d8 100644 --- a/doc/code/memory/10_schema_diagram.md +++ b/doc/code/memory/10_schema_diagram.md @@ -40,14 +40,16 @@ flowchart LR P_labels["labels (VARCHAR)"] P_prompt_metadata["prompt_metadata (VARCHAR)"] P_converter_identifiers["converter_identifiers (VARCHAR)"] - P_prompt_target_identifier["prompt_target_identifier (VARCHAR)"] - P_attack_identifier["attack_identifier (VARCHAR)"] P_response_error["response_error (VARCHAR)"] P_converted_value_data_type["converted_value_data_type (VARCHAR)"] P_converted_value["converted_value (VARCHAR)"] P_converted_value_sha256["converted_value_sha256 (VARCHAR)"] P_original_prompt_id["original_prompt_id (UUID)"] end + subgraph Conversations["Conversations"] + C_conversation_id["conversation_id (VARCHAR)"] + C_target_identifier["target_identifier (VARCHAR)"] + end subgraph ScoreEntries["ScoreEntries"] Sc_id["id (UUID)"] Sc_prompt_request_response_id["prompt_request_response_id (VARCHAR)"] @@ -63,6 +65,7 @@ flowchart LR end S_value_sha256 -- N:N relationship to query --> P_original_value_sha256 P_id -- 1:N relationship to query --> Sc_prompt_request_response_id + P_conversation_id -- N:1 relationship to query --> C_conversation_id style S_value_sha256 fill:#ff8800ff style P_id fill:#14a519ff diff --git a/doc/code/memory/3_memory_data_types.md b/doc/code/memory/3_memory_data_types.md index b35daaa005..0fd1b0988d 100644 --- a/doc/code/memory/3_memory_data_types.md +++ b/doc/code/memory/3_memory_data_types.md @@ -23,8 +23,6 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../ - **`labels`**: Dictionary of labels for categorization and filtering - **`prompt_metadata`**: Component-specific metadata (e.g., blob URIs, document types) - **`converter_identifiers`**: List of converters applied to transform the prompt -- **`prompt_target_identifier`**: Information about the target that received this prompt -- **`attack_identifier`**: Information about the attack that generated this prompt - **`scorer_identifier`**: Information about the scorer that evaluated this prompt - **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`) - **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`) @@ -54,6 +52,8 @@ This rich context allows PyRIT to track the full lifecycle of each interaction, A conversation is a list of `Messages` that share the same `conversation_id`. The sequence of the `MessagePieces` and their corresponding `Messages` dictates the order of the conversation. +A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory.get_conversation_metadata(conversation_id=...)` to retrieve it. + Here is a sample conversation made up of three `Messages` which all share the same conversation ID. The first `Message` is the `system` message, followed by a multi-modal `user` prompt with a text `MessagePiece` and an image `MessagePiece`, and finally the `assistant` response in the form of a text `MessagePiece`. ```{mermaid} diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 836a783fba..ea95ccf275 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -313,6 +313,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt cutoff_index=request.cutoff_index, labels_override=labels, remap_assistant_to_simulated=True, + target_identifier=target_identifier, ) else: conversation_id = str(uuid.uuid4()) @@ -345,6 +346,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt conversation_id=conversation_id, prepended=request.prepended_conversation, labels=labels, # deprecated + target_identifier=target_identifier, ) return CreateAttackResponse( @@ -476,9 +478,13 @@ async def create_related_conversation_async( # --- Branch via duplication (preferred for tracking) --------------- if request.source_conversation_id is not None and request.cutoff_index is not None: + source_metadata = self._memory.get_conversation_metadata( + conversation_id=request.source_conversation_id + ) new_conversation_id = self._duplicate_conversation_up_to( source_conversation_id=request.source_conversation_id, cutoff_index=request.cutoff_index, + target_identifier=source_metadata.target_identifier if source_metadata else None, ) else: new_conversation_id = str(uuid.uuid4()) @@ -623,11 +629,13 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR labels=attack_labels, # deprecated ) else: + existing_metadata = self._memory.get_conversation_metadata(conversation_id=msg_conversation_id) await self._store_message_only_async( conversation_id=msg_conversation_id, request=request, sequence=sequence, labels=attack_labels, # deprecated + target_identifier=existing_metadata.target_identifier if existing_metadata else None, ) await self._update_attack_after_message_async(attack_result_id=attack_result_id, ar=ar, request=request) @@ -829,6 +837,7 @@ def _duplicate_conversation_up_to( cutoff_index: int, labels_override: dict[str, str] | None = None, remap_assistant_to_simulated: bool = False, + target_identifier: ComponentIdentifier | None = None, ) -> str: """ Duplicate messages from a conversation up to and including a turn index. @@ -847,6 +856,9 @@ def _duplicate_conversation_up_to( ``assistant`` are changed to ``simulated_assistant`` so the branched context is inert and won't confuse the target. + target_identifier (ComponentIdentifier | None): The target the new conversation + is held with, if known. Recorded once for the duplicated conversation. + Returns: The new conversation ID containing the duplicated messages. """ @@ -866,7 +878,9 @@ def _duplicate_conversation_up_to( piece.role = "simulated_assistant" if all_pieces: - self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces)) + self._memory.add_message_pieces_to_memory( + message_pieces=list(all_pieces), target_identifier=target_identifier + ) return new_conversation_id @@ -954,6 +968,7 @@ async def _store_prepended_messages_async( conversation_id: str, prepended: list[Any], labels: dict[str, str] | None = None, # deprecated + target_identifier: ComponentIdentifier | None = None, ) -> None: """Store prepended conversation messages in memory.""" for seq, msg in enumerate(prepended): @@ -965,7 +980,9 @@ async def _store_prepended_messages_async( sequence=seq, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + self._memory.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target_identifier + ) async def _send_and_store_message_async( self, @@ -1011,6 +1028,7 @@ async def _store_message_only_async( request: AddMessageRequest, sequence: int, labels: dict[str, str] | None = None, # deprecated + target_identifier: ComponentIdentifier | None = None, ) -> None: """Store message without sending (send=False).""" await self._persist_base64_pieces_async(request) @@ -1022,7 +1040,9 @@ async def _store_message_only_async( sequence=sequence, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + self._memory.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target_identifier + ) def _resolve_video_remix_metadata(self, request: AddMessageRequest) -> None: """ diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 86127c765a..8a3cbac557 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -54,7 +54,6 @@ def get_adversarial_chat_messages( prepended_conversation: list[Message], *, adversarial_chat_conversation_id: str, - adversarial_chat_target_identifier: ComponentIdentifier, labels: dict[str, str] | None = None, # deprecated ) -> list[Message]: """ @@ -71,7 +70,6 @@ def get_adversarial_chat_messages( Args: prepended_conversation: The original conversation messages to transform. adversarial_chat_conversation_id: Conversation ID for the adversarial chat. - adversarial_chat_target_identifier (ComponentIdentifier): Target identifier for the adversarial chat. labels: Optional labels to associate with the messages. Deprecated: This parameter will be removed in a release 0.16.0. @@ -112,7 +110,6 @@ def get_adversarial_chat_messages( original_value_data_type=piece.original_value_data_type, converted_value_data_type=piece.converted_value_data_type, conversation_id=adversarial_chat_conversation_id, - prompt_target_identifier=adversarial_chat_target_identifier, labels=labels or {}, # deprecated ) @@ -352,6 +349,7 @@ async def initialize_context_async( request_converters=request_converters, prepended_conversation_config=prepended_conversation_config, max_turns=max_turns, + target_identifier=target.get_identifier(), ) async def _handle_non_chat_target_async( @@ -432,6 +430,7 @@ async def add_prepended_conversation_to_memory_async( request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, max_turns: int | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> int: """ Add prepended conversation messages to memory for a chat target. @@ -452,6 +451,8 @@ async def add_prepended_conversation_to_memory_async( request_converters: Optional converters to apply to messages. prepended_conversation_config: Optional configuration for converter roles. max_turns: If provided, validates that turn count doesn't exceed this limit. + target_identifier (ComponentIdentifier | None): The target the conversation is held + with, if known. Recorded once per conversation. Returns: The number of turns (assistant messages) added. @@ -498,7 +499,7 @@ async def add_prepended_conversation_to_memory_async( ) # Add to memory - self._memory.add_message_to_memory(request=message_copy) + self._memory.add_message_to_memory(request=message_copy, target_identifier=target_identifier) logger.debug(f"Added prepended message {i + 1}/{len(valid_messages)} to memory") return turn_count @@ -512,6 +513,7 @@ async def _process_prepended_for_chat_target_async( request_converters: list[PromptConverterConfiguration] | None, prepended_conversation_config: Optional["PrependedConversationConfig"], max_turns: int | None, + target_identifier: ComponentIdentifier | None = None, ) -> ConversationState: """ Process prepended conversation for a chat target. @@ -528,6 +530,8 @@ async def _process_prepended_for_chat_target_async( request_converters: Converters to apply. prepended_conversation_config: Configuration for converter roles. max_turns: Maximum turns for validation. + target_identifier (ComponentIdentifier | None): The objective target the + conversation is held with, if known. Returns: ConversationState with turn_count and scores. @@ -547,6 +551,7 @@ async def _process_prepended_for_chat_target_async( request_converters=request_converters, prepended_conversation_config=prepended_conversation_config, max_turns=max_turns, + target_identifier=target_identifier, ) # Update context for multi-turn attacks to reflect prepended_conversation diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 4aca72d054..bb3b07cef2 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -141,7 +141,9 @@ def _rotate_conversation_for_single_turn_target( if system_messages: new_conversation_id, pieces = memory.duplicate_messages(messages=system_messages) - memory.add_message_pieces_to_memory(message_pieces=pieces) + memory.add_message_pieces_to_memory( + message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + ) context.session.conversation_id = new_conversation_id else: context.session.conversation_id = str(uuid.uuid4()) diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index def0bd0113..08c18fc4c4 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -269,12 +269,13 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: adversarial_messages = get_adversarial_chat_messages( prepended_conversation=context.prepended_conversation, adversarial_chat_conversation_id=context.session.adversarial_chat_conversation_id, - adversarial_chat_target_identifier=self._adversarial_chat.get_identifier(), labels=context.memory_labels, ) for msg in adversarial_messages: - self._memory.add_message_to_memory(request=msg) + self._memory.add_message_to_memory( + request=msg, target_identifier=self._adversarial_chat.get_identifier() + ) async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> AttackResult: """ diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 0e40119c98..22d0541e9d 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -412,6 +412,7 @@ async def initialize_with_prepended_conversation_async( conversation_id=self.objective_target_conversation_id, request_converters=self._request_converters, prepended_conversation_config=prepended_conversation_config, + target_identifier=self._objective_target.get_identifier(), ) # Build context string for adversarial chat system prompt (like Crescendo) @@ -820,7 +821,9 @@ def duplicate(self) -> "_TreeOfAttacksNode": system_messages = [m for m in messages if m.api_role == "system"] if system_messages: new_id, pieces = self._memory.duplicate_messages(messages=system_messages) - self._memory.add_message_pieces_to_memory(message_pieces=pieces) + self._memory.add_message_pieces_to_memory( + message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + ) duplicate_node.objective_target_conversation_id = new_id else: duplicate_node.objective_target_conversation_id = str(uuid.uuid4()) diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py index dd9519bdf4..868526279f 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py @@ -96,7 +96,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 2f7ddf5be4..05eb69aab7 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -98,7 +98,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index e8fae2fa18..91b3ea127b 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -71,7 +71,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 2dfc3b183d..d7c84d9769 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -28,6 +28,7 @@ ) from pyrit.models import ( AzureBlobStorageIO, + ComponentIdentifier, ConversationStats, MessagePiece, ) @@ -698,7 +699,9 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. @@ -708,6 +711,8 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] Args: message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. Applied to every distinct ``conversation_id``. """ # ``not_in_memory`` pieces are ephemeral — typically synthesized inside a # scorer to score arbitrary content that never came through a real @@ -718,7 +723,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return - self._capture_conversations(message_pieces=pieces_to_insert) + self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def dispose_engine(self) -> None: @@ -830,18 +835,12 @@ def _query_entries( if join_scores and model_class == PromptMemoryEntry: query = query.options( joinedload(PromptMemoryEntry.scores), - joinedload(PromptMemoryEntry.conversation_metadata), ) elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), - joinedload(AttackResultEntry.last_response).joinedload( - PromptMemoryEntry.conversation_metadata - ), joinedload(AttackResultEntry.last_score), ) - elif model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.conversation_metadata)) if conditions is not None: query = query.filter(conditions) if order_by is not None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index e062185a2d..7fd8695757 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -347,12 +347,22 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> An """ @abc.abstractmethod - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. + + Args: + message_pieces (Sequence[MessagePiece]): The pieces to persist. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. A conversation is always with a single target, so + this is applied to every distinct ``conversation_id`` in ``message_pieces``. """ - def _capture_conversations(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def _capture_conversations( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Record one ``Conversations`` row per conversation for the given pieces. @@ -362,17 +372,26 @@ def _capture_conversations(self, *, message_pieces: Sequence[MessagePiece]) -> N -- normalizer, conversation duplication, prepended conversations, direct target writers -- captures the target through a single choke point. + A conversation is always held with a single target, so ``target_identifier`` + (when provided) is applied to every distinct ``conversation_id`` in this call. + A ``None`` target never overwrites a target already recorded for the + conversation (see ``_upsert_conversation``). + Args: message_pieces (Sequence[MessagePiece]): The pieces being persisted. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. """ - targets_by_conversation: dict[str, ComponentIdentifier | None] = {} + conversation_ids: list[str] = [] + seen: set[str] = set() for piece in message_pieces: if piece.not_in_memory: continue conversation_id = piece.conversation_id - if targets_by_conversation.get(conversation_id) is None: - targets_by_conversation[conversation_id] = piece.prompt_target_identifier - for conversation_id, target_identifier in targets_by_conversation.items(): + if conversation_id not in seen: + seen.add(conversation_id) + conversation_ids.append(conversation_id) + for conversation_id in conversation_ids: self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) def _upsert_conversation( @@ -1171,8 +1190,10 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: The uuid for the new conversation. """ messages = self.get_conversation(conversation_id=conversation_id) + source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) - self.add_message_pieces_to_memory(message_pieces=all_pieces) + self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) return new_conversation_id def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> str: @@ -1204,12 +1225,16 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> message for message in messages if message.sequence <= last_message.sequence - length_of_sequence_to_remove ] + source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) - self.add_message_pieces_to_memory(message_pieces=all_pieces) + self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) return new_conversation_id - def add_message_to_memory(self, *, request: Message) -> None: + def add_message_to_memory( + self, *, request: Message, target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. @@ -1217,7 +1242,9 @@ def add_message_to_memory(self, *, request: Message) -> None: If necessary, generates embedding data for applicable entries Args: - request (MessagePiece): The message piece to add to the memory. + request (Message): The message to add to the memory. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. Forwarded to ``add_message_pieces_to_memory``. """ request.validate() @@ -1226,7 +1253,7 @@ def add_message_to_memory(self, *, request: Message) -> None: self._update_sequence(message_pieces=message_pieces) - self.add_message_pieces_to_memory(message_pieces=message_pieces) + self.add_message_pieces_to_memory(message_pieces=message_pieces, target_identifier=target_identifier) if self.memory_embedding: for piece in message_pieces: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index d31012cce9..6e62af9b11 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -24,7 +24,6 @@ from sqlalchemy.orm import ( DeclarativeBase, Mapped, - foreign, mapped_column, relationship, ) @@ -293,18 +292,6 @@ class PromptMemoryEntry(Base): foreign_keys="ScoreEntry.prompt_request_response_id", ) - # Conversation-scoped metadata (e.g. the target identifier) lives in the - # ``Conversations`` table keyed by ``conversation_id`` rather than on every row. - # ``viewonly`` because this join is read-only (there is no FK constraint); reads - # eager-load it via ``joinedload`` so detached entries can still hydrate the - # target onto the reconstructed ``MessagePiece``. - conversation_metadata: Mapped["ConversationEntry | None"] = relationship( - "ConversationEntry", - primaryjoin=lambda: foreign(PromptMemoryEntry.conversation_id) == ConversationEntry.conversation_id, - viewonly=True, - uselist=False, - ) - def __init__(self, *, entry: MessagePiece) -> None: """ Initialize a PromptMemoryEntry from a MessagePiece. @@ -370,11 +357,6 @@ def get_message_piece(self) -> MessagePiece: message_piece.labels = self.labels or {} message_piece.targeted_harm_categories = self.targeted_harm_categories or [] message_piece.scores = [score.get_score() for score in self.scores] - # The target identifier is conversation-scoped: hydrate it from the - # ``Conversations`` row (eager-loaded via ``conversation_metadata``) so it is - # served once per conversation rather than stored on every piece. - if self.conversation_metadata is not None: - message_piece.prompt_target_identifier = self.conversation_metadata.get_conversation().target_identifier return message_piece def __str__(self) -> str: diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index ce961f2fbb..df5649ef14 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -29,7 +29,7 @@ PromptMemoryEntry, ScenarioResultEntry, ) -from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece +from pyrit.models import ComponentIdentifier, ConversationStats, DiskStorageIO, MessagePiece logger = logging.getLogger(__name__) @@ -301,18 +301,25 @@ def _get_condition_json_array_match( combined = joiner.join(conditions) return text(f"({combined})").bindparams(**bindparams_dict) - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. Pieces flagged via ``MessagePiece.not_in_memory = True`` are silently filtered out so callers don't need to track persistence policy themselves. + + Args: + message_pieces (Sequence[MessagePiece]): The pieces to persist. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. Applied to every distinct ``conversation_id``. """ pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return - self._capture_conversations(message_pieces=pieces_to_insert) + self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: @@ -364,19 +371,13 @@ def _query_entries( if join_scores and model_class == PromptMemoryEntry: query = query.options( joinedload(PromptMemoryEntry.scores), - joinedload(PromptMemoryEntry.conversation_metadata), ) elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response) .joinedload(PromptMemoryEntry.scores), - joinedload(AttackResultEntry.last_response).joinedload( - PromptMemoryEntry.conversation_metadata - ), joinedload(AttackResultEntry.last_score), ) - elif model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.conversation_metadata)) if conditions is not None: query = query.filter(conditions) if order_by is not None: diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index 32bbf0f0be..e4cb34a121 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -205,7 +205,6 @@ def construct_response_from_request( original_value=resp_text, conversation_id=request.conversation_id, labels=request.labels, - prompt_target_identifier=request.prompt_target_identifier, original_value_data_type=response_type, converted_value_data_type=response_type, prompt_metadata=prompt_metadata or {}, diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index e9a25a3121..e922a53d28 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -113,7 +113,6 @@ class MessagePiece(BaseModel): targeted_harm_categories: list[str] = Field(default_factory=list) prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) - prompt_target_identifier: ComponentIdentifierField | None = None scorer_identifier: ComponentIdentifierField | None = None scores: list[Score] = Field(default_factory=list) @@ -219,7 +218,7 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: Copy lineage metadata from ``source`` onto this piece. Lineage fields are the metadata that tie a piece back to its originating - conversation and target. Mutable containers (``labels``, + conversation. Mutable containers (``labels``, ``prompt_metadata``) are shallow-copied so that mutations on one piece do not affect others. @@ -228,7 +227,6 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: """ self.conversation_id = source.conversation_id self.labels = dict(source.labels) - self.prompt_target_identifier = source.prompt_target_identifier self.prompt_metadata = dict(source.prompt_metadata) def has_error(self) -> bool: diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 3998820c38..5a37c58bd3 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -450,7 +450,6 @@ def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> list[Message]: role=role, original_value=prompt.value, original_value_data_type=prompt.data_type or "text", - prompt_target_identifier=None, conversation_id=str(prompt.prompt_group_id), sequence=sequence, prompt_metadata=prompt.metadata, diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index ae7f763ab8..210e6a5ab9 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -174,7 +174,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=converted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self._converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 9df030fd1a..55cc40b68c 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -118,12 +118,12 @@ async def send_prompt_async( # Prepare the request by updating conversation ID, labels, and attack identifier request = copy.deepcopy(message) conversation_id = conversation_id if conversation_id else str(uuid4()) + target_identifier = target.get_identifier() for piece in request.message_pieces: piece.conversation_id = conversation_id if labels: piece.labels = labels # deprecated - piece.prompt_target_identifier = target.get_identifier() # Apply request converters await self.convert_values_async(converter_configurations=request_converter_configurations, message=request) @@ -134,10 +134,10 @@ async def send_prompt_async( try: responses = await target.send_prompt_async(message=request) - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) responses = [ construct_response_from_request( @@ -150,7 +150,7 @@ async def send_prompt_async( except Exception as ex: # Ensure request to memory before processing exception - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) error_response = construct_response_from_request( request=request.message_pieces[0], @@ -160,7 +160,7 @@ async def send_prompt_async( ) await self._calc_hash_async(request=error_response) - self.memory.add_message_to_memory(request=error_response) + self.memory.add_message_to_memory(request=error_response, target_identifier=target_identifier) cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex @@ -177,7 +177,7 @@ async def send_prompt_async( error="empty", ) await self._calc_hash_async(request=empty_response) - self.memory.add_message_to_memory(request=empty_response) + self.memory.add_message_to_memory(request=empty_response, target_identifier=target_identifier) return empty_response # Process all response messages (targets return list[Message]) @@ -190,7 +190,7 @@ async def send_prompt_async( converter_configurations=response_converter_configurations, message=resp ) await self._calc_hash_async(request=resp) - self.memory.add_message_to_memory(request=resp) + self.memory.add_message_to_memory(request=resp, target_identifier=target_identifier) # Return the last response for backward compatibility return responses[-1] @@ -384,7 +384,9 @@ async def _calc_hash_async(self, request: Message) -> None: tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.message_pieces] await asyncio.gather(*tasks) - async def hash_and_persist_message_async(self, *, message: Message) -> None: + async def hash_and_persist_message_async( + self, *, message: Message, target_identifier: ComponentIdentifier | None = None + ) -> None: """ Hash and persist a Message to memory. @@ -393,9 +395,11 @@ async def hash_and_persist_message_async(self, *, message: Message) -> None: Args: message (Message): The message to hash and persist. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. """ await self._calc_hash_async(request=message) - self.memory.add_message_to_memory(request=message) + self.memory.add_message_to_memory(request=message, target_identifier=target_identifier) async def add_prepended_conversation_to_memory_async( self, @@ -404,6 +408,7 @@ async def add_prepended_conversation_to_memory_async( converter_configurations: list[PromptConverterConfiguration] | None = None, attack_identifier: ComponentIdentifier | None = None, prepended_conversation: list[Message] | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> list[Message] | None: """ Process the prepended conversation by converting it if needed and adding it to memory. @@ -416,6 +421,8 @@ async def add_prepended_conversation_to_memory_async( attack_identifier (ComponentIdentifier | None): Identifier for the attack. Deprecated: this parameter is ignored and will be removed in release 0.17.0. prepended_conversation (list[Message] | None): The conversation to prepend + target_identifier (ComponentIdentifier | None): The target the conversation is held + with, if known. Recorded once per conversation. Returns: list[Message] | None: The processed prepended conversation @@ -443,7 +450,7 @@ async def add_prepended_conversation_to_memory_async( # and if not, this won't hurt anything piece.id = uuid4() - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) return prepended_conversation @@ -467,6 +474,7 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt converter_configurations: list[PromptConverterConfiguration] | None = None, attack_identifier: ComponentIdentifier | None = None, prepended_conversation: list[Message] | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> list[Message] | None: """ Use ``add_prepended_conversation_to_memory_async`` instead; this is a deprecated alias. @@ -485,6 +493,7 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt converter_configurations=converter_configurations, attack_identifier=attack_identifier, prepended_conversation=prepended_conversation, + target_identifier=target_identifier, ) diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 44e0313b89..ba91daee80 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -322,7 +322,9 @@ async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, ret prompt_metadata=_probe_metadata(), ) try: - target._memory.add_message_to_memory(request=Message(message_pieces=[system_piece])) + target._memory.add_message_to_memory( + request=Message(message_pieces=[system_piece]), target_identifier=target.get_identifier() + ) except Exception as exc: logger.debug("System-prompt probe could not seed system message: %s", exc) return False @@ -406,7 +408,9 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie # Seed memory so the second send sees real prior history. try: - target._memory.add_message_to_memory(request=Message(message_pieces=[first])) + target._memory.add_message_to_memory( + request=Message(message_pieces=[first]), target_identifier=target.get_identifier() + ) assistant_reply = MessagePiece( role="assistant", original_value="Got it.", @@ -414,7 +418,9 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie conversation_id=conversation_id, prompt_metadata=_probe_metadata(), ).to_message() - target._memory.add_message_to_memory(request=assistant_reply) + target._memory.add_message_to_memory( + request=assistant_reply, target_identifier=target.get_identifier() + ) except Exception as exc: logger.debug("Multi-turn probe could not seed conversation history: %s", exc) return False diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 91393732fa..3ff6416bb4 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -333,9 +333,9 @@ def set_system_prompt( conversation_id=conversation_id, original_value=system_prompt, converted_value=system_prompt, - prompt_target_identifier=self.get_identifier(), labels=labels or {}, - ).to_message() + ).to_message(), + target_identifier=self.get_identifier(), ) def dispose_db_engine(self) -> None: diff --git a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py index c287c6e392..6d7a0767e4 100644 --- a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py +++ b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py @@ -203,6 +203,7 @@ async def run_async(self) -> AsyncIterator[Message]: conversation_id=self._conversation_id, should_convert=False, prepended_conversation=self._prepended_conversation, + target_identifier=self._target.get_identifier(), ) self._queue = asyncio.Queue() @@ -408,7 +409,6 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: converted_value=converted_user_path, converted_value_data_type="audio_path", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, ) for cfg in self._request_converter_configurations: user_piece.converter_identifiers.extend(converter.get_identifier() for converter in cfg.converters) @@ -419,14 +419,12 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: original_value=result.flatten_transcripts(), original_value_data_type="text", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, ) assistant_audio_piece = MessagePiece( role="assistant", original_value=assistant_audio_path, original_value_data_type="audio_path", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, ) if result.interrupted: assistant_text_piece.prompt_metadata[STREAMING_INTERRUPTED_KEY] = True @@ -439,8 +437,12 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: message=assistant_message, ) - await self._prompt_normalizer.hash_and_persist_message_async(message=user_message) - await self._prompt_normalizer.hash_and_persist_message_async(message=assistant_message) + await self._prompt_normalizer.hash_and_persist_message_async( + message=user_message, target_identifier=target_identifier + ) + await self._prompt_normalizer.hash_and_persist_message_async( + message=assistant_message, target_identifier=target_identifier + ) return assistant_message # ---- Wire helpers ------------------------------------------------------- diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index ad3d5f3d52..f0a797af93 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -717,7 +717,6 @@ def _parse_response_output_section( original_value=piece_value, conversation_id=message_piece.conversation_id, labels=message_piece.labels, # deprecated - prompt_target_identifier=message_piece.prompt_target_identifier, original_value_data_type=piece_type, response_error=error or "none", ) @@ -824,5 +823,4 @@ def _make_tool_piece(self, output: dict[str, Any], call_id: str, *, reference_pi original_value_data_type="function_call_output", conversation_id=reference_piece.conversation_id, labels={"call_id": call_id}, # deprecated - prompt_target_identifier=reference_piece.prompt_target_identifier, ) diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 5f0c26a515..d545aa58ca 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -241,7 +241,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me original_value=piece_data, conversation_id=request_piece.conversation_id, labels=request_piece.labels, # deprecated - prompt_target_identifier=request_piece.prompt_target_identifier, original_value_data_type=piece_type, converted_value_data_type=piece_type, prompt_metadata=request_piece.prompt_metadata, diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 8e0deed295..ba69a8dcdb 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -86,12 +86,13 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: sequence=int(sequence_str) if sequence_str else 0, labels=labels, # deprecated response_error=row.get("response_error", None), - prompt_target_identifier=self.get_identifier(), ) message_pieces.append(message_piece) # This is post validation, so the message_pieces should be okay and normalized - self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) + self._memory.add_message_pieces_to_memory( + message_pieces=message_pieces, target_identifier=self.get_identifier() + ) return message_pieces def _validate_request(self, *, normalized_conversation: list[Message]) -> None: diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index 64ece08807..9c0842ce38 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -102,7 +102,6 @@ async def _score_async(self, message: Message, *, objective: str | None = None) id=original_piece.id, conversation_id=original_piece.conversation_id, labels=original_piece.labels, # deprecated - prompt_target_identifier=original_piece.prompt_target_identifier, original_value_data_type="text", converted_value_data_type="text", response_error="none", diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index a43e14b56a..5cef5744e2 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -349,7 +349,6 @@ def _create_text_piece_from_blocked(piece: MessagePiece) -> MessagePiece | None: labels=piece.labels, prompt_metadata=piece.prompt_metadata, converter_identifiers=list(piece.converter_identifiers), # type: ignore[arg-type] - prompt_target_identifier=piece.prompt_target_identifier, response_error="none", timestamp=piece.timestamp, ) @@ -738,7 +737,6 @@ async def _score_value_with_llm_async( original_value_data_type="text", converted_value_data_type="text", conversation_id=conversation_id, - prompt_target_identifier=prompt_target.get_identifier(), prompt_metadata=prompt_metadata, ) ) @@ -751,7 +749,6 @@ async def _score_value_with_llm_async( original_value_data_type=message_data_type, converted_value_data_type=message_data_type, conversation_id=conversation_id, - prompt_target_identifier=prompt_target.get_identifier(), prompt_metadata=prompt_metadata, ) ) diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 5fc51fbc25..e28ad75bbc 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -128,7 +128,6 @@ async def _check_for_password_in_conversation_async(self, conversation_id: str) original_value=conversation_as_text, converted_value=conversation_as_text, conversation_id=scoring_conversation_id, - prompt_target_identifier=self._prompt_target.get_identifier(), ) ] ) diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index a320e89fa9..d6ac555610 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -76,7 +76,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st original_value=body, prompt_metadata=message_piece.prompt_metadata, conversation_id=conversation_id, - prompt_target_identifier=self._prompt_target.get_identifier(), ) ] ) diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index f02367f6f8..a268d0b744 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -2175,9 +2175,14 @@ class TestAttackServiceAdditionalCoverage: async def test_create_related_conversation_uses_duplicate_branch(self, attack_service, mock_memory): """When source_conversation_id and cutoff_index are provided, duplication path is used.""" from pyrit.backend.models.attacks import CreateConversationRequest + from pyrit.models import Conversation ar = make_attack_result(conversation_id="attack-1") mock_memory.get_attack_results.return_value = [ar] + expected_target = ComponentIdentifier(class_name="TextTarget", class_module="pyrit.prompt_target") + mock_memory.get_conversation_metadata.return_value = Conversation( + conversation_id="attack-1", target_identifier=expected_target + ) with patch.object(attack_service, "_duplicate_conversation_up_to", return_value="branch-dup") as mock_dup: result = await attack_service.create_related_conversation_async( @@ -2187,7 +2192,11 @@ async def test_create_related_conversation_uses_duplicate_branch(self, attack_se assert result is not None assert result.conversation_id == "branch-dup" - mock_dup.assert_called_once_with(source_conversation_id="attack-1", cutoff_index=2) + mock_dup.assert_called_once_with( + source_conversation_id="attack-1", + cutoff_index=2, + target_identifier=expected_target, + ) async def test_add_message_merges_converter_identifiers_without_duplicates(self, attack_service, mock_memory): """Should merge new converter identifiers with existing attack identifiers by hash.""" diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 58835daa14..d3d107732d 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -251,7 +251,6 @@ def test_swaps_user_to_assistant(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -266,7 +265,6 @@ def test_swaps_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -284,7 +282,6 @@ def test_swaps_simulated_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -302,7 +299,6 @@ def test_skips_system_messages(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # Only user message should be present, system skipped @@ -318,7 +314,6 @@ def test_assigns_new_uuids(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # New ID should be different from original @@ -339,7 +334,6 @@ def test_preserves_message_content(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result[0].get_piece().original_value == "Original content" @@ -350,7 +344,6 @@ def test_empty_prepended_conversation(self) -> None: result = get_adversarial_chat_messages( [], adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result == [] @@ -364,7 +357,6 @@ def test_applies_labels(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels=labels, ) @@ -381,7 +373,6 @@ def test_labels_emit_deprecation_warning(self) -> None: get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels={"env": "prod"}, ) diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index ae7ccaafaa..111c9e00bf 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -28,7 +28,7 @@ def _make_strategy(*, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} with patch.multiple( MultiTurnAttackStrategy, @@ -378,13 +378,13 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} scorer = MagicMock() - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} seed = MagicMock() seed.render_template_value.return_value = "template" @@ -694,14 +694,14 @@ def _make_single_turn_target(self): target.configuration = TargetConfiguration( capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), ) - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return target def _make_adversarial_config(self): from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return AttackAdversarialConfig(target=adversarial_chat) def _make_scoring_config(self): @@ -709,7 +709,7 @@ def _make_scoring_config(self): from pyrit.score import TrueFalseScorer scorer = MagicMock(spec=TrueFalseScorer) - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return AttackScoringConfig(objective_scorer=scorer) async def test_crescendo_raises_for_single_turn_target(self): @@ -752,13 +752,13 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} scorer = MagicMock() - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} seed = MagicMock() seed.render_template_value.return_value = "template" diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py index c088e0b0bc..e8466493da 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py @@ -14,7 +14,7 @@ FuzzerShortenConverter, FuzzerSimilarConverter, ) -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece @pytest.mark.parametrize( @@ -89,7 +89,6 @@ async def test_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index a703bb0ed4..ad5cd1ddb6 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -1320,14 +1320,20 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI MessagePiece( role="user", original_value="Hello OpenAI", - prompt_target_identifier=target_id_1, + conversation_id="conv-openai", ), + ], + target_identifier=target_id_1, + ) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[ MessagePiece( role="user", original_value="Hello Azure", - prompt_target_identifier=target_id_2, + conversation_id="conv-azure", ), - ] + ], + target_identifier=target_id_2, ) # Filter by target hash diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 385fd26809..ddcd509718 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -204,10 +204,11 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): converted_value="Test content", labels={"normalizer_id": "id1"}, converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), ) - memory_interface.add_message_pieces_to_memory(message_pieces=[piece]) + memory_interface.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target.get_identifier() + ) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) @@ -225,8 +226,10 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): assert len(converter_identifiers) == 1 assert converter_identifiers[0].class_name == "Base64Converter" - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" + # The target identifier is conversation-scoped and stored in the Conversations table. + metadata = memory_interface.get_conversation_metadata(conversation_id=specific_conversation_id) + assert metadata is not None + assert metadata.target_identifier.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index ee3bab305f..ca6912d749 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -52,7 +52,6 @@ def _make_message_piece(**overrides) -> MessagePiece: "labels": {"label1": "value1"}, "prompt_metadata": {"meta": "data"}, "converter_identifiers": [ComponentIdentifier(class_name="NoOp", class_module="pyrit.converters")], - "prompt_target_identifier": ComponentIdentifier(class_name="MockTarget", class_module="tests.mocks"), "original_value_data_type": "text", "converted_value_data_type": "text", "response_error": "none", @@ -234,8 +233,8 @@ def test_roundtrip_get_message_piece(self): assert recovered.conversation_id == piece.conversation_id assert isinstance(recovered.converter_identifiers[0], ComponentIdentifier) - def test_str_without_target_identifier(self): - piece = _make_message_piece(prompt_target_identifier=None) + def test_str_renders_role_and_value(self): + piece = _make_message_piece() entry = PromptMemoryEntry(entry=piece) s = str(entry) assert "user" in s diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index bc7265c360..95c05d40fc 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -526,10 +526,11 @@ def test_get_memories_with_json_properties(sqlite_instance): converted_value="Test content", labels={"normalizer_id": "id1"}, converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), ) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target.get_identifier() + ) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) @@ -547,8 +548,10 @@ def test_get_memories_with_json_properties(sqlite_instance): assert len(converter_identifiers) == 1 assert converter_identifiers[0].class_name == "Base64Converter" - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" + # The target identifier is conversation-scoped and stored in the Conversations table. + metadata = sqlite_instance.get_conversation_metadata(conversation_id=specific_conversation_id) + assert metadata is not None + assert metadata.target_identifier.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index b8c64404d5..1b64fb4378 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -370,11 +370,6 @@ def test_to_dict_from_dict_roundtrip(): class_name="SelfAskTrueFalseScorer", class_module="pyrit.score", ) - target_id = ComponentIdentifier( - class_name="OpenAIChatTarget", - class_module="pyrit.prompt_target", - params={"endpoint": "https://api.example.com"}, - ) attack_id = ComponentIdentifier( class_name="PromptSendingAttack", class_module="pyrit.executor.attack", @@ -386,7 +381,6 @@ def test_to_dict_from_dict_roundtrip(): conversation_id="conv-1", sequence=1, timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), - prompt_target_identifier=target_id, ) last_score = Score( score_value="true", diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index d4872ccec9..1809257d35 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -10,7 +10,7 @@ from unittest.mock import patch import pytest -from unit.mocks import MockPromptTarget, get_sample_conversations +from unit.mocks import get_sample_conversations from pyrit.models import ( ComponentIdentifier, @@ -69,19 +69,6 @@ def test_converters_serialize(): assert converter.class_module == "pyrit.prompt_converter.base64_converter" -def test_prompt_targets_serialize(patch_central_database): - target = MockPromptTarget() - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - prompt_target_identifier=target.get_identifier(), - ) - assert patch_central_database.called - assert entry.prompt_target_identifier.class_name == "MockPromptTarget" - assert entry.prompt_target_identifier.class_module == "unit.mocks" - - async def test_hashes_generated(): entry = MessagePiece( role="user", @@ -673,10 +660,6 @@ def test_message_piece_to_dict(): params={"supported_input_types": ["text"], "supported_output_types": ["text"]}, ) ], - prompt_target_identifier=ComponentIdentifier( - class_name="MockPromptTarget", - class_module="unit.mocks", - ), scorer_identifier=ComponentIdentifier( class_name="TestScorer", class_module="pyrit.score.test_scorer", @@ -719,7 +702,6 @@ def test_message_piece_to_dict(): "targeted_harm_categories", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", "scorer_identifier", "original_value_data_type", "original_value", @@ -746,7 +728,6 @@ def test_message_piece_to_dict(): assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] - assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value @@ -1071,7 +1052,6 @@ def test_to_dict_from_dict_roundtrip(): timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), prompt_metadata={"doc_type": "text"}, converter_identifiers=[converter_id], - prompt_target_identifier=target_id, original_value_data_type="text", converted_value_data_type="text", response_error="none", @@ -1118,7 +1098,6 @@ def _make_piece(self, **overrides) -> MessagePiece: def test_copies_lineage_fields_from_source_to_target(self) -> None: source = self._make_piece( conversation_id="conv-A", - prompt_target_identifier={"__type__": "Target", "__module__": "x", "id": "tgt-1"}, ) source.prompt_metadata = {"k": "v"} @@ -1127,7 +1106,6 @@ def test_copies_lineage_fields_from_source_to_target(self) -> None: target.copy_lineage_from(source=source) assert target.conversation_id == "conv-A" - assert target.prompt_target_identifier == source.prompt_target_identifier assert target.prompt_metadata == {"k": "v"} def test_labels_and_metadata_are_shallow_copied(self) -> None: @@ -1198,7 +1176,6 @@ def test_to_dict_golden_shape(self) -> None: "targeted_harm_categories", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", "scorer_identifier", "scores", ] @@ -1212,7 +1189,6 @@ def test_to_dict_golden_shape(self) -> None: assert d["targeted_harm_categories"] == [] assert d["prompt_metadata"] == {} assert d["converter_identifiers"] == [] - assert d["prompt_target_identifier"] is None assert d["scorer_identifier"] is None assert d["original_value_data_type"] == "text" assert d["original_value"] == "hello" diff --git a/tests/unit/prompt_converter/test_persuasion_converter.py b/tests/unit/prompt_converter/test_persuasion_converter.py index 02cd38301c..256eac967a 100644 --- a/tests/unit/prompt_converter/test_persuasion_converter.py +++ b/tests/unit/prompt_converter/test_persuasion_converter.py @@ -7,7 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import PersuasionConverter @@ -72,7 +72,6 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ) ] @@ -100,7 +99,6 @@ async def test_persuasion_converter_extracts_mutated_text(sqlite_instance): conversation_id="test-id", original_value='{"mutated_text": "rephrased prompt"}', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -122,7 +120,6 @@ async def test_persuasion_converter_missing_mutated_text_raises_invalid_json(sql conversation_id="test-id", original_value='{"other_key": "value"}', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -163,7 +160,6 @@ async def test_send_persuasion_prompt_async_emits_deprecation_warning_and_delega conversation_id="conv-1", original_value="test input", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) ] ) diff --git a/tests/unit/prompt_converter/test_translation_converter.py b/tests/unit/prompt_converter/test_translation_converter.py index 9073dea9ac..c5e1aa2702 100644 --- a/tests/unit/prompt_converter/test_translation_converter.py +++ b/tests/unit/prompt_converter/test_translation_converter.py @@ -7,7 +7,7 @@ import pytest from unit.mocks import MockPromptTarget -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import TranslationConverter @@ -40,7 +40,6 @@ async def test_translation_converter_returns_stripped_response(sqlite_instance): conversation_id="test-id", original_value=" hola \n", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -72,7 +71,6 @@ async def test_translation_converter_user_prompt_byte_for_byte_equivalent(sqlite conversation_id="test-id", original_value="hola", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -119,7 +117,6 @@ async def test_translation_converter_succeeds_after_retries(sqlite_instance): converted_value="hola", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test-identifier", class_module="test"), sequence=1, ) ] diff --git a/tests/unit/prompt_converter/test_variation_converter.py b/tests/unit/prompt_converter/test_variation_converter.py index 1357894a2a..e62d880f9d 100644 --- a/tests/unit/prompt_converter/test_variation_converter.py +++ b/tests/unit/prompt_converter/test_variation_converter.py @@ -7,7 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import VariationConverter @@ -44,7 +44,6 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ) ] @@ -71,7 +70,6 @@ async def test_variation_converter_extracts_first_element_from_json_list(sqlite_ conversation_id="test-id", original_value='["first variation", "second variation"]', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -92,7 +90,6 @@ async def test_variation_converter_preserves_original_and_converted_values(sqlit conversation_id="test-id", original_value='["variation"]', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -127,7 +124,6 @@ async def test_send_variation_prompt_async_emits_deprecation_warning_and_delegat conversation_id="conv-1", original_value="test input", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) ] ) diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 899a75b1d9..55a4a4f818 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -844,4 +844,5 @@ async def test_add_prepended_conversation_to_memory_emits_deprecation_warning_an converter_configurations=None, attack_identifier=None, prepended_conversation=None, + target_identifier=None, ) diff --git a/tests/unit/prompt_target/target/test_http_target.py b/tests/unit/prompt_target/target/test_http_target.py index e31fa005af..aef915d8b2 100644 --- a/tests/unit/prompt_target/target/test_http_target.py +++ b/tests/unit/prompt_target/target/test_http_target.py @@ -70,7 +70,6 @@ async def test_send_prompt_async(mock_request, mock_http_target, mock_http_respo MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -125,7 +124,6 @@ async def test_send_prompt_async_client_kwargs(patch_central_database): MagicMock( converted_value="", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -164,7 +162,6 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -200,7 +197,6 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -228,7 +224,6 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http MagicMock( converted_value="second_test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -285,7 +280,6 @@ async def test_http_target_with_injected_client(patch_central_database): MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 5a97782ce6..3b14a24169 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -16,7 +16,7 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.message_normalizer import GenericSystemSquashNormalizer -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_target import AzureMLChatTarget, OpenAIChatTarget from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, @@ -36,7 +36,6 @@ def _make_message_piece(*, role: str, content: str, conversation_id: str = "conv converted_value=content, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 6da03a1311..e32ce00100 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -25,7 +25,7 @@ RateLimitException, ) from pyrit.memory.memory_interface import MemoryInterface -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target import ( OpenAIChatAudioConfig, @@ -283,7 +283,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -293,7 +292,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] @@ -391,7 +389,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -401,7 +398,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] @@ -454,7 +450,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -464,7 +459,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py index 97c71618c1..451e353ece 100644 --- a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py +++ b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py @@ -314,7 +314,7 @@ async def test_run_async_swaps_user_audio_and_records_identifiers_when_request_c persisted_user_messages: list[Message] = [] - async def _capture(*, message: Message) -> None: + async def _capture(*, message: Message, target_identifier=None) -> None: if message.message_pieces[0].api_role == "user": persisted_user_messages.append(message) @@ -349,7 +349,7 @@ async def test_run_async_skips_swap_and_identifiers_when_no_request_converters() persisted_user_messages: list[Message] = [] - async def _capture(*, message: Message) -> None: + async def _capture(*, message: Message, target_identifier=None) -> None: if message.message_pieces[0].api_role == "user": persisted_user_messages.append(message) diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index da758724b4..6a9cb9952f 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -23,7 +23,7 @@ RateLimitException, ) from pyrit.memory.memory_interface import MemoryInterface -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target import OpenAIResponseTarget, PromptTarget @@ -306,7 +306,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -316,7 +315,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] @@ -397,7 +395,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -407,7 +404,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] @@ -441,7 +437,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -451,7 +446,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 8b89fcb175..161070de85 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -161,7 +161,6 @@ async def test_send_prompt_async_with_delay( _LINEAGE_CONVERSATION_ID = "original-conv-id-12345" _LINEAGE_LABELS = {"op_name": "test_op", "user_id": "user42"} -_LINEAGE_PROMPT_TARGET_IDENTIFIER = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit") _LINEAGE_PROMPT_METADATA = {"scenario": "test_scenario", "turn": 3} @@ -174,7 +173,6 @@ def _make_lineage_piece(*, role: str, content: str) -> MessagePiece: original_value_data_type="text", converted_value_data_type="text", labels=dict(_LINEAGE_LABELS), - prompt_target_identifier=_LINEAGE_PROMPT_TARGET_IDENTIFIER, prompt_metadata=dict(_LINEAGE_PROMPT_METADATA), ) @@ -237,7 +235,6 @@ async def test_history_squash_preserves_metadata_on_normalized_message(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -285,7 +282,6 @@ async def test_response_preserves_metadata_after_history_squash(): assert response_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert response_piece.labels == _LINEAGE_LABELS - assert response_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert response_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -331,7 +327,6 @@ async def test_system_squash_preserves_metadata(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -381,7 +376,6 @@ async def test_history_squash_propagates_lineage_to_all_pieces(): for piece in normalized[0].message_pieces: assert piece.conversation_id == _LINEAGE_CONVERSATION_ID assert piece.labels == _LINEAGE_LABELS - assert piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -444,7 +438,6 @@ async def test_conversation_id_stamped_on_all_but_full_lineage_only_on_last(): # Last message should carry full lineage. last_piece = normalized[-1].message_pieces[0] assert last_piece.labels == _LINEAGE_LABELS - assert last_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert last_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA # Warning should fire because message count increased (2 → 3). diff --git a/tests/unit/prompt_target/test_round_robin_target.py b/tests/unit/prompt_target/test_round_robin_target.py index 6218ea5013..bb2bf55bf7 100644 --- a/tests/unit/prompt_target/test_round_robin_target.py +++ b/tests/unit/prompt_target/test_round_robin_target.py @@ -373,12 +373,10 @@ async def test_full_send_prompt_async_keeps_round_robin_identifier(): for piece in message.message_pieces: piece.conversation_id = conv_id # Simulate what PromptNormalizer does - piece.prompt_target_identifier = rr.get_identifier() responses = await rr.send_prompt_async(message=message) # The request should still have the round-robin's identifier - assert message.message_pieces[0].prompt_target_identifier == rr.get_identifier() # Only t1 should have received the prompt (first in rotation) assert t1.prompt_sent == ["end to end test"] diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index c326f8bf15..8c6c3ff71d 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -251,7 +251,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data original_value="Response", conversation_id=conversation_id, labels={"test": "label"}, - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) @@ -287,7 +286,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data assert called_piece.id == message_piece.id assert called_piece.conversation_id == message_piece.conversation_id assert called_piece.labels == message_piece.labels - assert called_piece.prompt_target_identifier == message_piece.prompt_target_identifier async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(patch_central_database): From 8e3c5494adc4d1a374ee35647e11dd40784f822f Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sun, 7 Jun 2026 09:24:58 -0700 Subject: [PATCH 03/12] Fix ConversationEntry docstring and add None-clobber upsert test Correct the stale ConversationEntry docstring (target metadata is read via get_conversation_metadata, not rehydrated onto pieces) and add a regression test asserting a None target_identifier on a later write does not overwrite the target already recorded for a conversation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/memory_models.py | 3 ++- tests/unit/memory/test_sqlite_memory.py | 31 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 6e62af9b11..afec28d01e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -376,7 +376,8 @@ class ConversationEntry(Base): Holds identifiers that belong to the conversation as a whole -- currently the target identifier -- so they are not duplicated onto every ``PromptMemoryEntry`` row. The target is captured once when the conversation's pieces are written and - rehydrated onto pieces on read. + read back via ``MemoryInterface.get_conversation_metadata`` (it is not stamped + onto individual pieces). """ __tablename__ = "Conversations" diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 95c05d40fc..c22f6a0bca 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -557,6 +557,37 @@ def test_get_memories_with_json_properties(sqlite_instance): assert labels["normalizer_id"] == "id1" +def test_capture_conversation_none_target_does_not_clobber(sqlite_instance): + # A conversation is held with a single target. The request piece records the + # target; a later write for the same conversation that has no target (e.g. a + # response or branched copy) must NOT overwrite the recorded target with None. + conversation_id = "conv-none-clobber" + target = TextTarget() + + request_piece = MessagePiece( + conversation_id=conversation_id, + role="user", + sequence=1, + original_value="hello", + ) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[request_piece], target_identifier=target.get_identifier() + ) + + response_piece = MessagePiece( + conversation_id=conversation_id, + role="assistant", + sequence=2, + original_value="world", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece], target_identifier=None) + + metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + assert metadata is not None + assert metadata.target_identifier is not None + assert metadata.target_identifier.class_name == "TextTarget" + + def test_update_entries(sqlite_instance): # Insert a test entry entry = PromptMemoryEntry( From 3ad8d7544ab8898849ab05481c6ecda062476639 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 8 Jun 2026 18:47:22 -0700 Subject: [PATCH 04/12] Make conversation_id caller-owned and harden memory choke point Implements two approved PR review comments: - MessagePiece.conversation_id is now str | None (no auto-UUID default); the caller/normalizer owns id generation. - add_message_pieces_to_memory is a concrete Template Method; subclasses implement only _add_message_pieces_to_storage. A persistable piece with no conversation_id now fails loud instead of being silently assigned one. Blast-radius fixes: scorer-derived pieces (audio/video) copy the source conversation_id, human-labeled dataset rows get a fresh id per single-turn conversation, and CSV import raises a clear error on a missing id. Also restores the atomic attack identifier in the fairness/bias fallback. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/benchmark/fairness_bias.py | 2 + pyrit/memory/azure_sql_memory.py | 34 +++------ pyrit/memory/memory_interface.py | 73 ++++++++++++++++++- pyrit/memory/sqlite_memory.py | 25 ++----- pyrit/models/messages/message_piece.py | 7 +- pyrit/prompt_target/text_target.py | 14 +++- pyrit/score/audio_transcript_scorer.py | 1 + .../human_labeled_dataset.py | 2 + pyrit/score/video_scorer.py | 2 + .../test_interface_prompts.py | 65 ++++++++++++++++- .../memory_interface/test_interface_scores.py | 2 + .../target/test_prompt_target_text.py | 26 +++++++ tests/unit/score/test_audio_scorer.py | 1 + tests/unit/score/test_self_ask_refusal.py | 8 +- 14 files changed, 213 insertions(+), 49 deletions(-) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 4ab0fe432a..4e3bfaa505 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -21,6 +21,7 @@ from pyrit.models import ( AttackOutcome, AttackResult, + ComponentIdentifier, Message, build_atomic_attack_identifier, ) @@ -197,6 +198,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta objective=context.generated_objective, outcome=AttackOutcome.FAILURE, atomic_attack_identifier=build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier.of(self), ), labels=context.memory_labels, ) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index f3bfe8e126..f7c21ed651 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -27,7 +27,7 @@ PromptMemoryEntry, ) from pyrit.memory.storage import AzureBlobStorageIO -from pyrit.models import ComponentIdentifier, ConversationStats, MessagePiece +from pyrit.models import ConversationStats, MessagePiece if TYPE_CHECKING: from azure.core.credentials import AccessToken @@ -695,32 +695,20 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def add_message_pieces_to_memory( - self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None - ) -> None: + def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ - Insert a list of message pieces into the memory storage. + Persist already-validated message pieces to the Azure SQL store. - Pieces flagged via ``MessagePiece.not_in_memory = True`` are - silently filtered out so callers don't need to track persistence policy - themselves. + ``not_in_memory`` pieces are ephemeral -- typically synthesized inside a + scorer to score arbitrary content that never came through a real + PromptTarget. They are filtered out upstream in + ``add_message_pieces_to_memory`` before this method is called. Args: - message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added. - target_identifier (ComponentIdentifier | None): The target the conversation(s) - are held with, if known. Applied to every distinct ``conversation_id``. - """ - # ``not_in_memory`` pieces are ephemeral — typically synthesized inside a - # scorer to score arbitrary content that never came through a real - # PromptTarget. They have no conversation, target, or attack lineage, so - # persisting them would pollute the memory store with rows that don't - # tie to any real exchange. Filtering here lets every caller share one - # policy instead of guarding each call site. - pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] - if not pieces_to_insert: - return - self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) - self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) + message_pieces (Sequence[MessagePiece]): Persistable pieces (filtered and + validated by ``add_message_pieces_to_memory``). + """ + self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) def dispose_engine(self) -> None: """ diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a1c2219cd4..66ba5ab351 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -349,19 +349,75 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> An Any: A SQLAlchemy condition for filtering memory entries based on prompt metadata. """ - @abc.abstractmethod def add_message_pieces_to_memory( self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None ) -> None: """ Insert a list of message pieces into the memory storage. + Pieces flagged via ``MessagePiece.not_in_memory = True`` are silently filtered + out so callers don't need to track persistence policy themselves. Every + remaining piece must carry a non-empty ``conversation_id`` (the memory layer + never invents one -- see ``_validate_persistable_conversation_ids``). The + conversation-scoped metadata row is captured once per ``conversation_id`` via + ``_capture_conversations`` before the storage-specific insert. + + This is a template method: subclasses implement only the backend-specific + ``_add_message_pieces_to_storage`` and inherit the filtering, validation, and + conversation-capture steps so no subclass can forget to run them. + Args: message_pieces (Sequence[MessagePiece]): The pieces to persist. target_identifier (ComponentIdentifier | None): The target the conversation(s) are held with, if known. A conversation is always with a single target, so this is applied to every distinct ``conversation_id`` in ``message_pieces``. """ + pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] + if not pieces_to_insert: + return + self._validate_persistable_conversation_ids(message_pieces=pieces_to_insert) + self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) + self._add_message_pieces_to_storage(message_pieces=pieces_to_insert) + + @abc.abstractmethod + def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: + """ + Persist already-validated message pieces to the backing store. + + Called by ``add_message_pieces_to_memory`` after ``not_in_memory`` pieces are + filtered out, conversation_ids are validated, and the ``Conversations`` rows are + captured. Implementations only translate the pieces into storage rows and insert + them; they must not re-filter or re-validate. + + Args: + message_pieces (Sequence[MessagePiece]): Persistable pieces (none flagged + ``not_in_memory``), each carrying a non-empty ``conversation_id``. + """ + + @staticmethod + def _validate_persistable_conversation_ids(*, message_pieces: Sequence[MessagePiece]) -> None: + """ + Ensure every persistable piece carries a usable ``conversation_id``. + + A conversation is its own entity, so the caller that starts it owns the id; the + memory layer never generates one. Any piece reaching persistence without a + non-empty, non-blank ``conversation_id`` is a programming error and raises loudly + rather than being silently assigned a throwaway conversation. + + Args: + message_pieces (Sequence[MessagePiece]): Pieces about to be persisted + (``not_in_memory`` pieces should already be filtered out). + + Raises: + ValueError: If any piece has a ``None``, empty, or whitespace-only + ``conversation_id``. + """ + for piece in message_pieces: + if piece.conversation_id is None or not piece.conversation_id.strip(): + raise ValueError( + f"MessagePiece {piece.id} has no conversation_id. A conversation_id must be set by " + "the caller before a piece is persisted; the memory layer does not generate one." + ) def _capture_conversations( self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None @@ -391,6 +447,8 @@ def _capture_conversations( if piece.not_in_memory: continue conversation_id = piece.conversation_id + if not conversation_id: + continue if conversation_id not in seen: seen.add(conversation_id) conversation_ids.append(conversation_id) @@ -413,10 +471,15 @@ def _upsert_conversation( is held with, if known. Raises: + ValueError: If ``conversation_id`` is empty (a piece reached persistence + without a caller-assigned conversation_id; callers must set one). SQLAlchemyError: If the upsert fails. """ if not conversation_id: - return + raise ValueError( + "Cannot upsert a Conversations row without a conversation_id. This indicates a message " + "piece reached persistence without a caller-assigned conversation_id." + ) entry = ConversationEntry( conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) ) @@ -1254,6 +1317,12 @@ def add_message_to_memory( embedding_entries = [] message_pieces = request.message_pieces + pieces_to_persist = [piece for piece in message_pieces if not piece.not_in_memory] + if not pieces_to_persist: + return + + self._validate_persistable_conversation_ids(message_pieces=pieces_to_persist) + self._update_sequence(message_pieces=message_pieces) self.add_message_pieces_to_memory(message_pieces=message_pieces, target_identifier=target_identifier) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 17c0f9485b..3fcb28326c 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -30,7 +30,7 @@ ScenarioResultEntry, ) from pyrit.memory.storage import DiskStorageIO -from pyrit.models import ComponentIdentifier, ConversationStats, MessagePiece +from pyrit.models import ConversationStats, MessagePiece logger = logging.getLogger(__name__) @@ -302,26 +302,15 @@ def _get_condition_json_array_match( combined = joiner.join(conditions) return text(f"({combined})").bindparams(**bindparams_dict) - def add_message_pieces_to_memory( - self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None - ) -> None: + def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ - Insert a list of message pieces into the memory storage. - - Pieces flagged via ``MessagePiece.not_in_memory = True`` are - silently filtered out so callers don't need to track persistence policy - themselves. + Persist already-validated message pieces to the SQLite store. Args: - message_pieces (Sequence[MessagePiece]): The pieces to persist. - target_identifier (ComponentIdentifier | None): The target the conversation(s) - are held with, if known. Applied to every distinct ``conversation_id``. - """ - pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] - if not pieces_to_insert: - return - self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) - self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) + message_pieces (Sequence[MessagePiece]): Persistable pieces (filtered and + validated by ``add_message_pieces_to_memory``). + """ + self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index da097b129c..b77e77ca46 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -96,7 +96,7 @@ class MessagePiece(BaseModel): id: uuid.UUID = Field(default_factory=uuid4) role: ChatMessageRole - conversation_id: str = Field(default_factory=lambda: str(uuid4())) + conversation_id: str | None = None sequence: int = -1 timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) original_value: str @@ -333,4 +333,7 @@ def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece convo_id: min(x.timestamp for x in message_pieces if x.conversation_id == convo_id) for convo_id in {x.conversation_id for x in message_pieces} } - return sorted(message_pieces, key=lambda x: (earliest_timestamps[x.conversation_id], x.conversation_id, x.sequence)) + return sorted( + message_pieces, + key=lambda x: (earliest_timestamps[x.conversation_id], x.conversation_id or "", x.sequence), + ) diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index ba69a8dcdb..5e24e107a0 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -67,22 +67,32 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: Returns: list[MessagePiece]: A list of message pieces imported from the CSV. + + Raises: + ValueError: If a row is missing a ``conversation_id``. """ message_pieces = [] with open(csv_file_path, newline="") as csvfile: csvreader = csv.DictReader(csvfile) - for row in csvreader: + for row_number, row in enumerate(csvreader, start=1): sequence_str = row.get("sequence", None) labels_str = row.get("labels", None) labels = json.loads(labels_str) if labels_str else None + conversation_id = row.get("conversation_id", None) + if not conversation_id or not conversation_id.strip(): + raise ValueError( + f"Row {row_number} of '{csv_file_path}' is missing a 'conversation_id'. " + "Every imported row must specify the conversation it belongs to." + ) + message_piece = MessagePiece( role=row["role"], original_value=row["value"], original_value_data_type=row.get("data_type", None), - conversation_id=row.get("conversation_id", None), + conversation_id=conversation_id, sequence=int(sequence_str) if sequence_str else 0, labels=labels, # deprecated response_error=row.get("response_error", None), diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 17bf2af6c0..707766386b 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -192,6 +192,7 @@ async def _score_audio_async(self, *, message_piece: MessagePiece, objective: st original_prompt_id=original_prompt_id, converted_value=transcript, converted_value_data_type="text", + conversation_id=message_piece.conversation_id, ) text_message = text_piece.to_message() diff --git a/pyrit/score/scorer_evaluation/human_labeled_dataset.py b/pyrit/score/scorer_evaluation/human_labeled_dataset.py index 937573d840..5856a40aed 100644 --- a/pyrit/score/scorer_evaluation/human_labeled_dataset.py +++ b/pyrit/score/scorer_evaluation/human_labeled_dataset.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, cast +from uuid import uuid4 import pandas as pd @@ -318,6 +319,7 @@ def from_csv( role="assistant", original_value=response_to_score, original_value_data_type=cast("PromptDataType", data_type), + conversation_id=str(uuid4()), ) ], ) diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index ea69978357..bc7aec92cc 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -133,6 +133,7 @@ async def _score_frames_async(self, *, message_piece: MessagePiece, objective: s original_prompt_id=original_prompt_id, converted_value=frame, converted_value_data_type="image_path", + conversation_id=message_piece.conversation_id, ) response = piece.to_message() image_requests.append(response) @@ -248,6 +249,7 @@ async def _score_video_audio_async( original_prompt_id=original_prompt_id, converted_value=audio_path, converted_value_data_type="audio_path", + conversation_id=message_piece.conversation_id, ) audio_message = audio_piece.to_message() diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index ad5cd1ddb6..973a88234c 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -70,18 +70,21 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface pieces = [ MessagePiece( + conversation_id=str(uuid4()), id=uuid1, role="user", original_value="Test prompt 1", converted_value="Test prompt 1", ), MessagePiece( + conversation_id=str(uuid4()), id=uuid2, role="assistant", original_value="Test prompt 2", converted_value="Test prompt 2", ), MessagePiece( + conversation_id=str(uuid4()), id=uuid3, role="user", original_value="Test prompt 3", @@ -114,6 +117,7 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface def test_get_message_pieces_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface): piece = MessagePiece( + conversation_id=str(uuid4()), id=uuid.uuid4(), role="user", original_value="Test prompt", @@ -547,7 +551,7 @@ def test_duplicate_conversation_with_multiple_pieces(sqlite_instance: MemoryInte def test_add_message_pieces_to_memory_calls_validate(sqlite_instance: MemoryInterface): message = MagicMock(Message) - message.message_pieces = [MagicMock(MessagePiece)] + message.message_pieces = [MagicMock(MessagePiece, not_in_memory=False, conversation_id="test-conversation")] with ( patch("pyrit.memory.sqlite_memory.SQLiteMemory.add_message_pieces_to_memory"), patch("pyrit.memory.memory_interface.MemoryInterface._update_sequence"), @@ -556,6 +560,35 @@ def test_add_message_pieces_to_memory_calls_validate(sqlite_instance: MemoryInte assert message.validate.called +@pytest.mark.parametrize("bad_id", [None, "", " "]) +def test_add_message_pieces_to_memory_raises_when_conversation_id_missing( + sqlite_instance: MemoryInterface, bad_id +): + piece = MessagePiece(role="user", original_value="hello", conversation_id=bad_id) + with pytest.raises(ValueError, match="conversation_id"): + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + + +@pytest.mark.parametrize("bad_id", [None, "", " "]) +def test_add_message_to_memory_raises_when_conversation_id_missing(sqlite_instance: MemoryInterface, bad_id): + piece = MessagePiece(role="user", original_value="hello", conversation_id=bad_id) + with pytest.raises(ValueError, match="conversation_id"): + sqlite_instance.add_message_to_memory(request=Message(message_pieces=[piece])) + + +def test_add_message_pieces_to_memory_skips_not_in_memory_without_conversation_id( + sqlite_instance: MemoryInterface, +): + # not_in_memory pieces are filtered out before persistence, so a missing + # conversation_id on an ephemeral piece must not raise. + ephemeral = MessagePiece(role="user", original_value="ephemeral", conversation_id=None) + ephemeral.not_in_memory = True + + sqlite_instance.add_message_pieces_to_memory(message_pieces=[ephemeral]) + + assert sqlite_instance.get_message_pieces() == [] + + def test_add_message_pieces_to_memory_updates_sequence( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): @@ -638,6 +671,7 @@ def test_get_message_pieces_labels(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", labels=labels, @@ -645,6 +679,7 @@ def test_get_message_pieces_labels(sqlite_instance: MemoryInterface): ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", labels=labels, @@ -652,6 +687,7 @@ def test_get_message_pieces_labels(sqlite_instance: MemoryInterface): ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -711,6 +747,7 @@ def test_get_message_pieces_labels_returns_pme_and_ar_label_matches(sqlite_insta # PME with direct labels pme_direct = PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Direct label", labels=labels, @@ -736,6 +773,7 @@ def test_get_message_pieces_labels_returns_pme_and_ar_label_matches(sqlite_insta # PME with no labels and no matching AR pme_no_match = PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="No match", ) @@ -781,6 +819,7 @@ def test_get_message_pieces_metadata(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", prompt_metadata=metadata, @@ -788,6 +827,7 @@ def test_get_message_pieces_metadata(sqlite_instance: MemoryInterface): ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", prompt_metadata={"key2": "value2", "key3": "value3"}, @@ -795,6 +835,7 @@ def test_get_message_pieces_metadata(sqlite_instance: MemoryInterface): ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -814,18 +855,21 @@ def test_get_message_pieces_id(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -885,18 +929,21 @@ def test_get_message_pieces_sent_after(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -918,18 +965,21 @@ def test_get_message_pieces_sent_before(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -952,18 +1002,21 @@ def test_get_message_pieces_by_value(sqlite_instance: MemoryInterface): entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ) ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ) @@ -981,14 +1034,17 @@ def test_get_message_pieces_by_value(sqlite_instance: MemoryInterface): def test_get_message_pieces_by_hash(sqlite_instance: MemoryInterface): entries = [ MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ), MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ), MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 3", ), @@ -1075,11 +1131,13 @@ def test_message_piece_scores_duplicate_piece(sqlite_instance: MemoryInterface): pieces = [ MessagePiece( + conversation_id=str(uuid4()), id=original_id, role="assistant", original_value="prompt text", ), MessagePiece( + conversation_id=str(uuid4()), id=duplicate_id, role="assistant", original_value="prompt text", @@ -1114,10 +1172,12 @@ def test_message_piece_scores_duplicate_piece(sqlite_instance: MemoryInterface): async def test_message_piece_hash_stored_and_retrieved(sqlite_instance: MemoryInterface): entries = [ MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="Hello 1", ), MessagePiece( + conversation_id=str(uuid4()), role="assistant", original_value="Hello 2", ), @@ -1391,6 +1451,7 @@ def test_get_message_pieces_by_converter_identifier_filter_with_array_element_pa entries = [ PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="With Base64", converter_identifiers=[converter_a], @@ -1398,6 +1459,7 @@ def test_get_message_pieces_by_converter_identifier_filter_with_array_element_pa ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="With both converters", converter_identifiers=[converter_a, converter_b], @@ -1405,6 +1467,7 @@ def test_get_message_pieces_by_converter_identifier_filter_with_array_element_pa ), PromptMemoryEntry( entry=MessagePiece( + conversation_id=str(uuid4()), role="user", original_value="No converters", ) diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index ead87d6666..ed12d309c8 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -151,6 +151,7 @@ def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: Memor role="user", original_value="original prompt text", converted_value="Hello, how are you?", + conversation_id=str(uuid4()), ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) @@ -221,6 +222,7 @@ def test_get_scores_by_memory_labels(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", sequence=0, labels={"sample": "label"}, + conversation_id=str(uuid4()), ) ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) diff --git a/tests/unit/prompt_target/target/test_prompt_target_text.py b/tests/unit/prompt_target/target/test_prompt_target_text.py index 93c81f2b8e..66b21d7c5f 100644 --- a/tests/unit/prompt_target/target/test_prompt_target_text.py +++ b/tests/unit/prompt_target/target/test_prompt_target_text.py @@ -4,6 +4,7 @@ import io import os from collections.abc import MutableSequence +from pathlib import Path from tempfile import NamedTemporaryFile import pytest @@ -52,3 +53,28 @@ async def test_send_prompt_stream(sample_entries: MutableSequence[MessagePiece]) os.remove(tmp_file.name) assert prompt in content, "The prompt was not found in the temporary file content." + + +@pytest.mark.usefixtures("patch_central_database") +def test_import_scores_from_csv_missing_conversation_id_raises(tmp_path: Path): + csv_path = tmp_path / "scores.csv" + csv_path.write_text("role,value\nassistant,hello\n", encoding="utf-8") + + no_op = TextTarget() + with pytest.raises(ValueError, match="conversation_id"): + no_op.import_scores_from_csv(csv_file_path=csv_path) + + +@pytest.mark.usefixtures("patch_central_database") +def test_import_scores_from_csv_with_conversation_id_succeeds(tmp_path: Path): + csv_path = tmp_path / "scores.csv" + csv_path.write_text( + "role,value,data_type,response_error,labels,conversation_id\nassistant,hello,text,none,{},conv-1\n", + encoding="utf-8", + ) + + no_op = TextTarget() + pieces = no_op.import_scores_from_csv(csv_file_path=csv_path) + + assert len(pieces) == 1 + assert pieces[0].conversation_id == "conv-1" diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index 88b43f4cd2..c063493091 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -85,6 +85,7 @@ def audio_message_piece(patch_central_database): converted_value=audio_path, original_value_data_type="audio_path", converted_value_data_type="audio_path", + conversation_id=str(uuid.uuid4()), ) message_piece.id = uuid.uuid4() diff --git a/tests/unit/score/test_self_ask_refusal.py b/tests/unit/score/test_self_ask_refusal.py index abb8e3f557..2872de3394 100644 --- a/tests/unit/score/test_self_ask_refusal.py +++ b/tests/unit/score/test_self_ask_refusal.py @@ -5,6 +5,7 @@ from pathlib import Path from textwrap import dedent from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 import pytest from unit.mocks import get_mock_target_identifier @@ -161,7 +162,12 @@ async def test_score_async_filtered_response(patch_central_database): chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskRefusalScorer(chat_target=chat_target) - request = MessagePiece(role="assistant", original_value="blocked response", response_error="blocked").to_message() + request = MessagePiece( + role="assistant", + original_value="blocked response", + response_error="blocked", + conversation_id=str(uuid4()), + ).to_message() memory.add_message_pieces_to_memory(message_pieces=request.message_pieces) scores = await scorer.score_async(request) From cec9ccdda369bd39123a431daab9a705934674d9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 8 Jun 2026 19:39:29 -0700 Subject: [PATCH 05/12] Add explicit conversation registration; drop target_identifier from message writes Introduce MemoryInterface.add_conversation_to_memory(*, conversation_id, target_identifier=None) so a conversation's target is registered once where the conversation is created. Low-level writes (add_message_to_memory, add_message_pieces_to_memory) no longer accept target_identifier and no longer touch the Conversations table; the unused _capture_conversations choke point is removed. Conversation-level helpers (normalizer/conversation_manager prepended flows, backend _store_* helpers, duplicate helpers) keep target_identifier and register once internally. A Conversations row is now created only on explicit registration rather than on every write; all readers already tolerate a missing/None target. Also remove a redundant second conversation_id validation in add_message_to_memory since every write funnels through add_message_pieces_to_memory. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 21 ++-- .../attack/component/conversation_manager.py | 4 +- .../multi_turn/multi_turn_attack_strategy.py | 5 +- .../executor/attack/multi_turn/red_teaming.py | 8 +- .../attack/multi_turn/tree_of_attacks.py | 5 +- pyrit/memory/memory_interface.py | 116 +++++++----------- pyrit/prompt_normalizer/prompt_normalizer.py | 28 ++--- .../common/discover_target_capabilities.py | 14 +-- pyrit/prompt_target/common/prompt_target.py | 4 +- .../_openai_realtime_streaming_session.py | 11 +- pyrit/prompt_target/text_target.py | 8 +- .../test_interface_prompts.py | 53 +++++++- tests/unit/memory/test_azure_sql_memory.py | 5 +- tests/unit/memory/test_sqlite_memory.py | 21 ++-- 14 files changed, 168 insertions(+), 135 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index f6cea27567..fe3b8f527d 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -477,9 +477,7 @@ async def create_related_conversation_async( # --- Branch via duplication (preferred for tracking) --------------- if request.source_conversation_id is not None and request.cutoff_index is not None: - source_metadata = self._memory.get_conversation_metadata( - conversation_id=request.source_conversation_id - ) + source_metadata = self._memory.get_conversation_metadata(conversation_id=request.source_conversation_id) new_conversation_id = self._duplicate_conversation_up_to( source_conversation_id=request.source_conversation_id, cutoff_index=request.cutoff_index, @@ -877,9 +875,10 @@ def _duplicate_conversation_up_to( piece.role = "simulated_assistant" if all_pieces: - self._memory.add_message_pieces_to_memory( - message_pieces=list(all_pieces), target_identifier=target_identifier + self._memory.add_conversation_to_memory( + conversation_id=new_conversation_id, target_identifier=target_identifier ) + self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces)) return new_conversation_id @@ -970,6 +969,9 @@ async def _store_prepended_messages_async( target_identifier: ComponentIdentifier | None = None, ) -> None: """Store prepended conversation messages in memory.""" + if not prepended: + return + self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) for seq, msg in enumerate(prepended): for p in msg.pieces: piece = request_piece_to_pyrit_message_piece( @@ -979,9 +981,7 @@ async def _store_prepended_messages_async( sequence=seq, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory( - message_pieces=[piece], target_identifier=target_identifier - ) + self._memory.add_message_pieces_to_memory(message_pieces=[piece]) async def _send_and_store_message_async( self, @@ -1031,6 +1031,7 @@ async def _store_message_only_async( ) -> None: """Store message without sending (send=False).""" await self._persist_base64_pieces_async(request) + self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) for p in request.pieces: piece = request_piece_to_pyrit_message_piece( piece=p, @@ -1039,9 +1040,7 @@ async def _store_message_only_async( sequence=sequence, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory( - message_pieces=[piece], target_identifier=target_identifier - ) + self._memory.add_message_pieces_to_memory(message_pieces=[piece]) def _resolve_video_remix_metadata(self, request: AddMessageRequest) -> None: """ diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 8a3cbac557..30b9305dac 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -465,6 +465,8 @@ async def add_prepended_conversation_to_memory_async( if not valid_messages: return 0 + self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + # Get roles that should have converters applied apply_to_roles = ( prepended_conversation_config.apply_converters_to_roles if prepended_conversation_config else None @@ -499,7 +501,7 @@ async def add_prepended_conversation_to_memory_async( ) # Add to memory - self._memory.add_message_to_memory(request=message_copy, target_identifier=target_identifier) + self._memory.add_message_to_memory(request=message_copy) logger.debug(f"Added prepended message {i + 1}/{len(valid_messages)} to memory") return turn_count diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index bb3b07cef2..72db4d75a1 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -141,9 +141,10 @@ def _rotate_conversation_for_single_turn_target( if system_messages: new_conversation_id, pieces = memory.duplicate_messages(messages=system_messages) - memory.add_message_pieces_to_memory( - message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + memory.add_conversation_to_memory( + conversation_id=new_conversation_id, target_identifier=self._objective_target.get_identifier() ) + memory.add_message_pieces_to_memory(message_pieces=pieces) context.session.conversation_id = new_conversation_id else: context.session.conversation_id = str(uuid.uuid4()) diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 08c18fc4c4..4c47fced33 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -272,10 +272,12 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: labels=context.memory_labels, ) + self._memory.add_conversation_to_memory( + conversation_id=context.session.adversarial_chat_conversation_id, + target_identifier=self._adversarial_chat.get_identifier(), + ) for msg in adversarial_messages: - self._memory.add_message_to_memory( - request=msg, target_identifier=self._adversarial_chat.get_identifier() - ) + self._memory.add_message_to_memory(request=msg) async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> AttackResult: """ diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 22d0541e9d..46909a8384 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -821,9 +821,10 @@ def duplicate(self) -> "_TreeOfAttacksNode": system_messages = [m for m in messages if m.api_role == "system"] if system_messages: new_id, pieces = self._memory.duplicate_messages(messages=system_messages) - self._memory.add_message_pieces_to_memory( - message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + self._memory.add_conversation_to_memory( + conversation_id=new_id, target_identifier=self._objective_target.get_identifier() ) + self._memory.add_message_pieces_to_memory(message_pieces=pieces) duplicate_node.objective_target_conversation_id = new_id else: duplicate_node.objective_target_conversation_id = str(uuid.uuid4()) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 66ba5ab351..410db0b273 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -349,34 +349,53 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> An Any: A SQLAlchemy condition for filtering memory entries based on prompt metadata. """ - def add_message_pieces_to_memory( - self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + def add_conversation_to_memory( + self, *, conversation_id: str, target_identifier: ComponentIdentifier | None = None ) -> None: + """ + Register a conversation in memory, recording its conversation-scoped metadata. + + A conversation is a first-class entity held with a single target. Call this once + when a conversation is created (before, or independently of, adding its messages) + to record the target it is held with. Message writes (``add_message_to_memory`` / + ``add_message_pieces_to_memory``) deliberately do not take a target, so that + conversation ownership is expressed in a single place rather than threaded through + every write. + + Registration is idempotent: a non-``None`` ``target_identifier`` is recorded, and + a ``None`` value never overwrites a target already recorded for the conversation. + + Args: + conversation_id (str): The caller-owned conversation identifier. + target_identifier (ComponentIdentifier | None): The target the conversation is + held with, if known. + """ + self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) + + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. Pieces flagged via ``MessagePiece.not_in_memory = True`` are silently filtered out so callers don't need to track persistence policy themselves. Every remaining piece must carry a non-empty ``conversation_id`` (the memory layer - never invents one -- see ``_validate_persistable_conversation_ids``). The - conversation-scoped metadata row is captured once per ``conversation_id`` via - ``_capture_conversations`` before the storage-specific insert. + never invents one -- see ``_validate_persistable_conversation_ids``). + + Conversation-scoped metadata (the target a conversation is held with) is not + recorded here; register it once via ``add_conversation_to_memory`` when the + conversation is created. This is a template method: subclasses implement only the backend-specific - ``_add_message_pieces_to_storage`` and inherit the filtering, validation, and - conversation-capture steps so no subclass can forget to run them. + ``_add_message_pieces_to_storage`` and inherit the filtering and validation + steps so no subclass can forget to run them. Args: message_pieces (Sequence[MessagePiece]): The pieces to persist. - target_identifier (ComponentIdentifier | None): The target the conversation(s) - are held with, if known. A conversation is always with a single target, so - this is applied to every distinct ``conversation_id`` in ``message_pieces``. """ pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return self._validate_persistable_conversation_ids(message_pieces=pieces_to_insert) - self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) self._add_message_pieces_to_storage(message_pieces=pieces_to_insert) @abc.abstractmethod @@ -419,51 +438,13 @@ def _validate_persistable_conversation_ids(*, message_pieces: Sequence[MessagePi "the caller before a piece is persisted; the memory layer does not generate one." ) - def _capture_conversations( - self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None - ) -> None: - """ - Record one ``Conversations`` row per conversation for the given pieces. - - Conversation-scoped metadata (currently the target identifier) is persisted - once per ``conversation_id`` instead of being stamped onto every piece. This - runs from each backend's ``add_message_pieces_to_memory`` so every write path - -- normalizer, conversation duplication, prepended conversations, direct - target writers -- captures the target through a single choke point. - - A conversation is always held with a single target, so ``target_identifier`` - (when provided) is applied to every distinct ``conversation_id`` in this call. - A ``None`` target never overwrites a target already recorded for the - conversation (see ``_upsert_conversation``). - - Args: - message_pieces (Sequence[MessagePiece]): The pieces being persisted. - target_identifier (ComponentIdentifier | None): The target the conversation(s) - are held with, if known. - """ - conversation_ids: list[str] = [] - seen: set[str] = set() - for piece in message_pieces: - if piece.not_in_memory: - continue - conversation_id = piece.conversation_id - if not conversation_id: - continue - if conversation_id not in seen: - seen.add(conversation_id) - conversation_ids.append(conversation_id) - for conversation_id in conversation_ids: - self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) - - def _upsert_conversation( - self, *, conversation_id: str, target_identifier: ComponentIdentifier | None - ) -> None: + def _upsert_conversation(self, *, conversation_id: str, target_identifier: ComponentIdentifier | None) -> None: """ Insert or update the ``Conversations`` row for ``conversation_id``. A non-``None`` ``target_identifier`` is written; a ``None`` value never - overwrites a target already recorded for the conversation (so response/copy - pieces and write ordering cannot clobber it). + overwrites a target already recorded for the conversation (so re-registration + and copy/duplicate flows cannot clobber it). Args: conversation_id (str): The conversation to record. @@ -471,15 +452,11 @@ def _upsert_conversation( is held with, if known. Raises: - ValueError: If ``conversation_id`` is empty (a piece reached persistence - without a caller-assigned conversation_id; callers must set one). + ValueError: If ``conversation_id`` is empty. SQLAlchemyError: If the upsert fails. """ if not conversation_id: - raise ValueError( - "Cannot upsert a Conversations row without a conversation_id. This indicates a message " - "piece reached persistence without a caller-assigned conversation_id." - ) + raise ValueError("Cannot register a conversation without a conversation_id.") entry = ConversationEntry( conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) ) @@ -1188,7 +1165,9 @@ def get_message_pieces( if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if identifier_filters: - conditions.extend(self._build_message_piece_identifier_conditions(identifier_filters=identifier_filters)) + conditions.extend( + self._build_message_piece_identifier_conditions(identifier_filters=identifier_filters) + ) # Identify list parameters that may need batching list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] @@ -1259,7 +1238,9 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) - self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) + if all_pieces: + self.add_conversation_to_memory(conversation_id=new_conversation_id, target_identifier=source_target) + self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> str: @@ -1294,13 +1275,13 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) - self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) + if all_pieces: + self.add_conversation_to_memory(conversation_id=new_conversation_id, target_identifier=source_target) + self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id - def add_message_to_memory( - self, *, request: Message, target_identifier: ComponentIdentifier | None = None - ) -> None: + def add_message_to_memory(self, *, request: Message) -> None: """ Insert a list of message pieces into the memory storage. @@ -1309,8 +1290,6 @@ def add_message_to_memory( Args: request (Message): The message to add to the memory. - target_identifier (ComponentIdentifier | None): The target the conversation - is held with, if known. Forwarded to ``add_message_pieces_to_memory``. """ request.validate() @@ -1321,11 +1300,10 @@ def add_message_to_memory( if not pieces_to_persist: return - self._validate_persistable_conversation_ids(message_pieces=pieces_to_persist) - self._update_sequence(message_pieces=message_pieces) - self.add_message_pieces_to_memory(message_pieces=message_pieces, target_identifier=target_identifier) + # conversation_id validation happens in add_message_pieces_to_memory, the shared choke point. + self.add_message_pieces_to_memory(message_pieces=message_pieces) if self.memory_embedding: for piece in message_pieces: diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index c005e687ab..30ccc2a698 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -119,6 +119,7 @@ async def send_prompt_async( request = copy.deepcopy(message) conversation_id = conversation_id if conversation_id else str(uuid4()) target_identifier = target.get_identifier() + self.memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) for piece in request.message_pieces: piece.conversation_id = conversation_id @@ -134,10 +135,10 @@ async def send_prompt_async( try: responses = await target.send_prompt_async(message=request) - self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=request) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution - self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=request) responses = [ construct_response_from_request( @@ -150,7 +151,7 @@ async def send_prompt_async( except Exception as ex: # Ensure request to memory before processing exception - self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=request) error_response = construct_response_from_request( request=request.message_pieces[0], @@ -160,7 +161,7 @@ async def send_prompt_async( ) await self._calc_hash_async(request=error_response) - self.memory.add_message_to_memory(request=error_response, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=error_response) cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex @@ -177,7 +178,7 @@ async def send_prompt_async( error="empty", ) await self._calc_hash_async(request=empty_response) - self.memory.add_message_to_memory(request=empty_response, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=empty_response) return empty_response # Process all response messages (targets return list[Message]) @@ -190,7 +191,7 @@ async def send_prompt_async( converter_configurations=response_converter_configurations, message=resp ) await self._calc_hash_async(request=resp) - self.memory.add_message_to_memory(request=resp, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=resp) # Return the last response for backward compatibility return responses[-1] @@ -384,22 +385,20 @@ async def _calc_hash_async(self, request: Message) -> None: tasks = [asyncio.create_task(set_message_piece_sha256_async(piece)) for piece in request.message_pieces] await asyncio.gather(*tasks) - async def hash_and_persist_message_async( - self, *, message: Message, target_identifier: ComponentIdentifier | None = None - ) -> None: + async def hash_and_persist_message_async(self, *, message: Message) -> None: """ Hash and persist a Message to memory. Use when a target assembles a Message outside the ``send_prompt_async`` flow - (e.g. streaming sessions that yield per-turn Messages directly). + (e.g. streaming sessions that yield per-turn Messages directly). Register the + conversation once via ``MemoryInterface.add_conversation_to_memory`` before + persisting its messages. Args: message (Message): The message to hash and persist. - target_identifier (ComponentIdentifier | None): The target the conversation - is held with, if known. """ await self._calc_hash_async(request=message) - self.memory.add_message_to_memory(request=message, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=message) async def add_prepended_conversation_to_memory_async( self, @@ -439,6 +438,7 @@ async def add_prepended_conversation_to_memory_async( # Create a deep copy of the prepended conversation to avoid modifying the original prepended_conversation = copy.deepcopy(prepended_conversation) + self.memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) for request in prepended_conversation: if should_convert and converter_configurations: @@ -450,7 +450,7 @@ async def add_prepended_conversation_to_memory_async( # and if not, this won't hurt anything piece.id = uuid4() - self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) + self.memory.add_message_to_memory(request=request) return prepended_conversation diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index ba91daee80..0a7c7cb93f 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -322,9 +322,10 @@ async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, ret prompt_metadata=_probe_metadata(), ) try: - target._memory.add_message_to_memory( - request=Message(message_pieces=[system_piece]), target_identifier=target.get_identifier() + target._memory.add_conversation_to_memory( + conversation_id=conversation_id, target_identifier=target.get_identifier() ) + target._memory.add_message_to_memory(request=Message(message_pieces=[system_piece])) except Exception as exc: logger.debug("System-prompt probe could not seed system message: %s", exc) return False @@ -408,9 +409,10 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie # Seed memory so the second send sees real prior history. try: - target._memory.add_message_to_memory( - request=Message(message_pieces=[first]), target_identifier=target.get_identifier() + target._memory.add_conversation_to_memory( + conversation_id=conversation_id, target_identifier=target.get_identifier() ) + target._memory.add_message_to_memory(request=Message(message_pieces=[first])) assistant_reply = MessagePiece( role="assistant", original_value="Got it.", @@ -418,9 +420,7 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie conversation_id=conversation_id, prompt_metadata=_probe_metadata(), ).to_message() - target._memory.add_message_to_memory( - request=assistant_reply, target_identifier=target.get_identifier() - ) + target._memory.add_message_to_memory(request=assistant_reply) except Exception as exc: logger.debug("Multi-turn probe could not seed conversation history: %s", exc) return False diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 3ff6416bb4..5d0cc93df7 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -327,6 +327,9 @@ def set_system_prompt( if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") + self._memory.add_conversation_to_memory( + conversation_id=conversation_id, target_identifier=self.get_identifier() + ) self._memory.add_message_to_memory( request=MessagePiece( role="system", @@ -335,7 +338,6 @@ def set_system_prompt( converted_value=system_prompt, labels=labels or {}, ).to_message(), - target_identifier=self.get_identifier(), ) def dispose_db_engine(self) -> None: diff --git a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py index 6d7a0767e4..9d5d5b0a35 100644 --- a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py +++ b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py @@ -402,6 +402,9 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: ) target_identifier = target.get_identifier() + target._memory.add_conversation_to_memory( + conversation_id=self._conversation_id, target_identifier=target_identifier + ) user_piece = MessagePiece( role="user", original_value=raw_user_path, @@ -437,12 +440,8 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: message=assistant_message, ) - await self._prompt_normalizer.hash_and_persist_message_async( - message=user_message, target_identifier=target_identifier - ) - await self._prompt_normalizer.hash_and_persist_message_async( - message=assistant_message, target_identifier=target_identifier - ) + await self._prompt_normalizer.hash_and_persist_message_async(message=user_message) + await self._prompt_normalizer.hash_and_persist_message_async(message=assistant_message) return assistant_message # ---- Wire helpers ------------------------------------------------------- diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 5e24e107a0..a4b69fd4a7 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -100,9 +100,11 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: message_pieces.append(message_piece) # This is post validation, so the message_pieces should be okay and normalized - self._memory.add_message_pieces_to_memory( - message_pieces=message_pieces, target_identifier=self.get_identifier() - ) + for conversation_id in {piece.conversation_id for piece in message_pieces if piece.conversation_id}: + self._memory.add_conversation_to_memory( + conversation_id=conversation_id, target_identifier=self.get_identifier() + ) + self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) return message_pieces def _validate_request(self, *, normalized_conversation: list[Message]) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 973a88234c..ba93bb4123 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -561,9 +561,7 @@ def test_add_message_pieces_to_memory_calls_validate(sqlite_instance: MemoryInte @pytest.mark.parametrize("bad_id", [None, "", " "]) -def test_add_message_pieces_to_memory_raises_when_conversation_id_missing( - sqlite_instance: MemoryInterface, bad_id -): +def test_add_message_pieces_to_memory_raises_when_conversation_id_missing(sqlite_instance: MemoryInterface, bad_id): piece = MessagePiece(role="user", original_value="hello", conversation_id=bad_id) with pytest.raises(ValueError, match="conversation_id"): sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) @@ -589,6 +587,51 @@ def test_add_message_pieces_to_memory_skips_not_in_memory_without_conversation_i assert sqlite_instance.get_message_pieces() == [] +def test_add_conversation_to_memory_records_target_for_plain_message_writes(sqlite_instance: MemoryInterface): + # Registering a conversation records its target once; subsequent message writes + # do not take a target, yet target-filtered reads still find the messages. + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + conversation_id = "conv-registered" + sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_id) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[MessagePiece(role="user", original_value="hi", conversation_id=conversation_id)] + ) + + metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + assert metadata is not None + assert metadata.target_identifier.hash == target_id.hash + + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + value=target_id.hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].conversation_id == conversation_id + + +def test_message_writes_without_registration_create_no_conversation_row(sqlite_instance: MemoryInterface): + # Message writes no longer touch the Conversations table; conversation metadata + # exists only when a conversation is explicitly registered. + conversation_id = "conv-unregistered" + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[MessagePiece(role="user", original_value="hi", conversation_id=conversation_id)] + ) + + assert sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) is None + # The messages themselves still persist. + assert len(sqlite_instance.get_message_pieces(conversation_id=conversation_id)) == 1 + + def test_add_message_pieces_to_memory_updates_sequence( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): @@ -1375,6 +1418,7 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, ) + sqlite_instance.add_conversation_to_memory(conversation_id="conv-openai", target_identifier=target_id_1) sqlite_instance.add_message_pieces_to_memory( message_pieces=[ MessagePiece( @@ -1383,8 +1427,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI conversation_id="conv-openai", ), ], - target_identifier=target_id_1, ) + sqlite_instance.add_conversation_to_memory(conversation_id="conv-azure", target_identifier=target_id_2) sqlite_instance.add_message_pieces_to_memory( message_pieces=[ MessagePiece( @@ -1393,7 +1437,6 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI conversation_id="conv-azure", ), ], - target_identifier=target_id_2, ) # Filter by target hash diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index ddcd509718..64c712861c 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -206,9 +206,10 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): converter_identifiers=converter_identifiers, ) - memory_interface.add_message_pieces_to_memory( - message_pieces=[piece], target_identifier=target.get_identifier() + memory_interface.add_conversation_to_memory( + conversation_id=specific_conversation_id, target_identifier=target.get_identifier() ) + memory_interface.add_message_pieces_to_memory(message_pieces=[piece]) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index c22f6a0bca..2be2fb09d4 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -528,9 +528,10 @@ def test_get_memories_with_json_properties(sqlite_instance): converter_identifiers=converter_identifiers, ) - sqlite_instance.add_message_pieces_to_memory( - message_pieces=[piece], target_identifier=target.get_identifier() + sqlite_instance.add_conversation_to_memory( + conversation_id=specific_conversation_id, target_identifier=target.get_identifier() ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) @@ -557,10 +558,10 @@ def test_get_memories_with_json_properties(sqlite_instance): assert labels["normalizer_id"] == "id1" -def test_capture_conversation_none_target_does_not_clobber(sqlite_instance): - # A conversation is held with a single target. The request piece records the - # target; a later write for the same conversation that has no target (e.g. a - # response or branched copy) must NOT overwrite the recorded target with None. +def test_register_conversation_none_target_does_not_clobber(sqlite_instance): + # A conversation is held with a single target. Registering it records the + # target; a later registration for the same conversation with no target (e.g. + # a branched copy whose source had no metadata) must NOT overwrite it with None. conversation_id = "conv-none-clobber" target = TextTarget() @@ -570,9 +571,10 @@ def test_capture_conversation_none_target_does_not_clobber(sqlite_instance): sequence=1, original_value="hello", ) - sqlite_instance.add_message_pieces_to_memory( - message_pieces=[request_piece], target_identifier=target.get_identifier() + sqlite_instance.add_conversation_to_memory( + conversation_id=conversation_id, target_identifier=target.get_identifier() ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[request_piece]) response_piece = MessagePiece( conversation_id=conversation_id, @@ -580,7 +582,8 @@ def test_capture_conversation_none_target_does_not_clobber(sqlite_instance): sequence=2, original_value="world", ) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece], target_identifier=None) + sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=None) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece]) metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) assert metadata is not None From 63594c7d63e74861ed4bff71309b593cb8d04472 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 9 Jun 2026 10:25:19 -0700 Subject: [PATCH 06/12] Rename _add_message_pieces_to_storage to _add_message_pieces_to_memory; fix video scorer test Rename the backend persistence hook to match its sibling _add_embeddings_to_memory and avoid overloading 'storage', which in this layer denotes file/blob IO (StorageIO/DiskStorageIO/AzureBlobStorageIO), not the DB. Updates the abstract def, its call site, both backend implementations, and the template-method docstring (which also no longer claims Conversations rows are captured here). Also set conversation_id on the video scorer test fixture so frame/audio pieces inherit a non-empty conversation_id, matching the caller-owned-id invariant (same fix already applied to the audio scorer fixture). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/azure_sql_memory.py | 2 +- pyrit/memory/memory_interface.py | 12 ++++++------ pyrit/memory/sqlite_memory.py | 5 ++--- tests/unit/score/test_video_scorer.py | 1 + 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index f7c21ed651..193e631a49 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -695,7 +695,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def _add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Persist already-validated message pieces to the Azure SQL store. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 410db0b273..5a9e126040 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -386,7 +386,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] conversation is created. This is a template method: subclasses implement only the backend-specific - ``_add_message_pieces_to_storage`` and inherit the filtering and validation + ``_add_message_pieces_to_memory`` and inherit the filtering and validation steps so no subclass can forget to run them. Args: @@ -396,17 +396,17 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] if not pieces_to_insert: return self._validate_persistable_conversation_ids(message_pieces=pieces_to_insert) - self._add_message_pieces_to_storage(message_pieces=pieces_to_insert) + self._add_message_pieces_to_memory(message_pieces=pieces_to_insert) @abc.abstractmethod - def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def _add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Persist already-validated message pieces to the backing store. Called by ``add_message_pieces_to_memory`` after ``not_in_memory`` pieces are - filtered out, conversation_ids are validated, and the ``Conversations`` rows are - captured. Implementations only translate the pieces into storage rows and insert - them; they must not re-filter or re-validate. + filtered out and conversation_ids are validated. Implementations only translate + the pieces into storage rows and insert them; they must not re-filter or + re-validate. Args: message_pieces (Sequence[MessagePiece]): Persistable pieces (none flagged diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3fcb28326c..5bc805c3f3 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -302,7 +302,7 @@ def _get_condition_json_array_match( combined = joiner.join(conditions) return text(f"({combined})").bindparams(**bindparams_dict) - def _add_message_pieces_to_storage(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def _add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Persist already-validated message pieces to the SQLite store. @@ -364,8 +364,7 @@ def _query_entries( ) elif model_class == AttackResultEntry: query = query.options( - joinedload(AttackResultEntry.last_response) - .joinedload(PromptMemoryEntry.scores), + joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), joinedload(AttackResultEntry.last_score), ) if conditions is not None: diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index e60d62bc56..7e9632093f 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -52,6 +52,7 @@ def video_converter_sample_video(tmp_path, patch_central_database): converted_value=video_path, original_value_data_type="video_path", converted_value_data_type="video_path", + conversation_id=str(uuid.uuid4()), ) message_piece.id = uuid.uuid4() yield message_piece From 5aac569cb296c16f84c4ffd01298798aa40e48bf Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 9 Jun 2026 12:00:17 -0700 Subject: [PATCH 07/12] Add coverage for conversation migration, upsert, and attack_identifier deprecation shims Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/backend/test_attack_service.py | 7 + .../test_interface_prompts.py | 34 +++++ tests/unit/memory/test_migration.py | 136 ++++++++++++++++++ .../test_prompt_normalizer.py | 35 +++++ .../test_openai_realtime_streaming_session.py | 29 ++++ .../target/test_prompt_target.py | 13 ++ 6 files changed, 254 insertions(+) diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index a268d0b744..0de9d88d38 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -2405,6 +2405,13 @@ def test_duplicate_conversation_remaps_assistant_to_simulated(self, attack_servi assert dup_piece.role == "simulated_assistant" + async def test_store_prepended_messages_noop_when_empty(self, attack_service, mock_memory): + """Empty prepended list should be a no-op: no conversation row and no piece writes.""" + await attack_service._store_prepended_messages_async(conversation_id="conv-1", prepended=[]) + + mock_memory.add_conversation_to_memory.assert_not_called() + mock_memory.add_message_pieces_to_memory.assert_not_called() + class TestAddMessageGuards: """Tests for target-mismatch and operator-mismatch guards in add_message_async.""" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index ba93bb4123..2c05701823 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -632,6 +632,40 @@ def test_message_writes_without_registration_create_no_conversation_row(sqlite_i assert len(sqlite_instance.get_message_pieces(conversation_id=conversation_id)) == 1 +def test_add_conversation_to_memory_updates_existing_target_on_reregister(sqlite_instance: MemoryInterface): + # Re-registering a conversation with a new non-null target overwrites the previously + # recorded one. (A None re-registration never clobbers -- covered separately.) + conversation_id = "conv-retarget" + target_a = ComponentIdentifier( + class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"endpoint": "a"} + ) + target_b = ComponentIdentifier( + class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"endpoint": "b"} + ) + sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_a) + sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_b) + + metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + assert metadata is not None + assert metadata.target_identifier.hash == target_b.hash + + +def test_upsert_conversation_rolls_back_and_reraises_on_db_error(sqlite_instance: MemoryInterface): + # A DB failure during the upsert rolls back the session and propagates the error + # rather than leaving a half-written Conversations row. + from sqlalchemy.exc import SQLAlchemyError + + session = MagicMock() + session.get.side_effect = SQLAlchemyError("boom") + + with patch.object(sqlite_instance, "get_session", return_value=session): + with pytest.raises(SQLAlchemyError, match="boom"): + sqlite_instance._upsert_conversation(conversation_id="conv-fail", target_identifier=None) + + session.rollback.assert_called_once() + session.commit.assert_not_called() + + def test_add_message_pieces_to_memory_updates_sequence( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index f2562a2fcc..9aed7ce4d6 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -585,3 +585,139 @@ def test_check_schema_migrations_not_silent_prints_output(capsys): assert f"{ALEMBIC_OUTPUT_PREFIX}No new upgrade operations detected." in captured.out finally: engine.dispose() + + +# ============================================================================= +# Backfill tests for the Conversations table migration (b2f4c6a8d1e3) +# ============================================================================= + + +_CONVERSATIONS_REV = "b2f4c6a8d1e3" +_CONVERSATIONS_PREV_REV = "9c8b7a6d5e4f" + +_TARGET_A = '{"name": "target-a"}' +_TARGET_B = '{"name": "target-b"}' + + +def _seed_pre_conversations_prompt_piece(connection, *, piece_id, conversation_id, sequence, target_identifier): + """Insert a PromptMemoryEntry row at the pre-Conversations revision.""" + connection.execute( + text( + 'INSERT INTO "PromptMemoryEntries" ' + "(id, role, conversation_id, sequence, timestamp, labels, prompt_metadata, " + "prompt_target_identifier, attack_identifier, original_value_data_type, " + "original_value, converted_value_data_type, original_prompt_id) " + "VALUES (:id, 'user', :conv, :seq, '2026-05-20', '{}', '{}', " + ":target, '{}', 'text', 'hello', 'text', :id)" + ), + {"id": piece_id, "conv": conversation_id, "seq": sequence, "target": target_identifier}, + ) + + +def test_conversations_migration_script_metadata(): + """The Conversations migration declares the expected revision chain.""" + from pyrit.memory.alembic.versions import b2f4c6a8d1e3_add_conversations_table as mig + + assert mig.revision == _CONVERSATIONS_REV + assert mig.down_revision == _CONVERSATIONS_PREV_REV + assert mig.branch_labels is None + assert mig.depends_on is None + + +def test_conversations_backfill_populates_targets_and_handles_conflicts(caplog): + """Upgrading to the Conversations revision backfills one row per conversation_id: + the target comes from PromptMemoryEntries (first non-null wins on conflict), + attack-only conversations get a null placeholder, and the per-row identifier + columns are dropped.""" + import logging + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "conversations-backfill.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + config = _config_for(connection) + command.upgrade(config, _CONVERSATIONS_PREV_REV) + + # A conversation whose two pieces share one target. + _seed_pre_conversations_prompt_piece( + connection, + piece_id=str(uuid.uuid4()), + conversation_id="conv-keep", + sequence=0, + target_identifier=_TARGET_A, + ) + _seed_pre_conversations_prompt_piece( + connection, + piece_id=str(uuid.uuid4()), + conversation_id="conv-keep", + sequence=1, + target_identifier=_TARGET_A, + ) + # A conversation with two distinct non-null targets -> first wins + warning. + _seed_pre_conversations_prompt_piece( + connection, + piece_id=str(uuid.uuid4()), + conversation_id="conv-conflict", + sequence=0, + target_identifier=_TARGET_A, + ) + _seed_pre_conversations_prompt_piece( + connection, + piece_id=str(uuid.uuid4()), + conversation_id="conv-conflict", + sequence=1, + target_identifier=_TARGET_B, + ) + # A conversation referenced only by an AttackResultEntry (no prompt rows). + _seed_pre_migration_attack_result( + connection, attack_id=str(uuid.uuid4()), conversation_id="conv-attack-only" + ) + + with caplog.at_level(logging.WARNING): + command.upgrade(config, _CONVERSATIONS_REV) + + rows = connection.execute( + text('SELECT conversation_id, target_identifier FROM "Conversations" ORDER BY conversation_id') + ).fetchall() + prompt_cols = {c["name"] for c in inspect(connection).get_columns("PromptMemoryEntries")} + + targets_by_conv = {r[0]: r[1] for r in rows} + + assert set(targets_by_conv) == {"conv-keep", "conv-conflict", "conv-attack-only"} + assert targets_by_conv["conv-keep"] == _TARGET_A + assert targets_by_conv["conv-conflict"] == _TARGET_A # first non-null wins + assert targets_by_conv["conv-attack-only"] is None # placeholder for attack-only conversation + + # The conflicting targets produced a warning. + assert any("multiple distinct" in r.message for r in caplog.records) + + # The per-row identifier columns are gone. + assert "prompt_target_identifier" not in prompt_cols + assert "attack_identifier" not in prompt_cols + finally: + engine.dispose() + + +def test_conversations_migration_downgrade_restores_columns(): + """Downgrading drops the Conversations table and re-adds the per-row identifier columns.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "conversations-downgrade.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + config = _config_for(connection) + command.upgrade(config, _CONVERSATIONS_REV) + + assert "Conversations" in set(inspect(connection).get_table_names()) + cols_up = {c["name"] for c in inspect(connection).get_columns("PromptMemoryEntries")} + assert "prompt_target_identifier" not in cols_up + + command.downgrade(config, _CONVERSATIONS_PREV_REV) + + assert "Conversations" not in set(inspect(connection).get_table_names()) + cols_down = {c["name"] for c in inspect(connection).get_columns("PromptMemoryEntries")} + assert "prompt_target_identifier" in cols_down + assert "attack_identifier" in cols_down + finally: + engine.dispose() diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 55a4a4f818..10501956cc 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -152,6 +152,24 @@ async def test_send_prompt_async_labels_emit_deprecation_warning(mock_memory_ins mock_deprecation.assert_called_once() +async def test_send_prompt_async_attack_identifier_emits_deprecation_warning(mock_memory_instance, seed_group): + prompt_target = MagicMock() + prompt_target.send_prompt_async = AsyncMock( + return_value=[MessagePiece(role="assistant", original_value="ok", conversation_id="conv-1").to_message()] + ) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") + + normalizer = PromptNormalizer() + message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") + + with patch("pyrit.prompt_normalizer.prompt_normalizer.print_deprecation_message") as mock_deprecation: + await normalizer.send_prompt_async( + message=message, target=prompt_target, attack_identifier=get_mock_attack_identifier("TestAttack") + ) + + mock_deprecation.assert_called_once() + + async def test_send_prompt_async_empty_response_exception_handled(mock_memory_instance, seed_group): # Use MagicMock with send_prompt_async as AsyncMock to avoid coroutine warnings on other methods prompt_target = MagicMock() @@ -630,6 +648,23 @@ async def test_add_prepended_conversation_to_memory(mock_memory_instance): mock_memory_instance.add_message_to_memory.assert_called_once() +async def test_add_prepended_conversation_to_memory_attack_identifier_emits_deprecation_warning(mock_memory_instance): + normalizer = PromptNormalizer() + + piece = MessagePiece(role="user", original_value="prepended text", conversation_id="old-id") + message = Message(message_pieces=[piece]) + + with patch("pyrit.prompt_normalizer.prompt_normalizer.print_deprecation_message") as mock_deprecation: + await normalizer.add_prepended_conversation_to_memory_async( + conversation_id="test-conv-id", + should_convert=False, + prepended_conversation=[message], + attack_identifier=get_mock_attack_identifier("TestAttack"), + ) + + mock_deprecation.assert_called_once() + + _AUDIO_SAMPLE_RATE_HZ = 24000 _AUDIO_NUM_CHANNELS = 1 _AUDIO_SAMPLE_WIDTH_BYTES = 2 diff --git a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py index 451e353ece..9901158811 100644 --- a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py +++ b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py @@ -736,6 +736,35 @@ def _fake_session_ctor(**kwargs): assert captured["persist_prepended_conversation"] is False +@patch.dict("os.environ", _CLEAN_ENV) +def test_open_streaming_session_attack_identifier_emits_deprecation_warning(sqlite_instance): + """Passing the deprecated ``attack_identifier`` kwarg emits a deprecation message.""" + from pyrit.prompt_target import RealtimeTarget + + target = RealtimeTarget(api_key="k", endpoint="wss://test_url", model_name="test") + normalizer = _build_normalizer() + + async def _empty(): + if False: + yield b"" + + with ( + patch( + "pyrit.prompt_target.openai.openai_realtime_target._OpenAIRealtimeStreamingSession", + side_effect=lambda **kwargs: MagicMock(name="session"), + ), + patch("pyrit.prompt_target.openai.openai_realtime_target.print_deprecation_message") as mock_deprecation, + ): + target.open_streaming_session( + audio_chunks=_empty(), + prompt_normalizer=normalizer, + conversation_id="conv-X", + attack_identifier=MagicMock(name="attack_identifier"), + ) + + mock_deprecation.assert_called_once() + + # --------------------------------------------------------------------------- # 12. Direct unit tests for the _trim_snapshot_to_speech helper # --------------------------------------------------------------------------- diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 161070de85..635e4021bd 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -69,6 +69,19 @@ def test_set_system_prompt(azure_openai_target: OpenAIChatTarget, mock_attack_st assert chats[0].converted_value == "system prompt" +def test_set_system_prompt_attack_identifier_emits_deprecation_warning( + azure_openai_target: OpenAIChatTarget, mock_attack_strategy: AttackStrategy +): + with patch("pyrit.prompt_target.common.prompt_target.print_deprecation_message") as mock_deprecation: + azure_openai_target.set_system_prompt( + system_prompt="system prompt", + conversation_id="1", + attack_identifier=mock_attack_strategy.get_identifier(), + ) + + mock_deprecation.assert_called_once() + + async def test_set_system_prompt_adds_memory( azure_openai_target: OpenAIChatTarget, mock_attack_strategy: AttackStrategy ): From 4ae9bbcb763c13a17ec97d5c3a2e86641ccaf3ca Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 10 Jun 2026 10:29:47 -0700 Subject: [PATCH 08/12] Use Conversation object param and insert-only conversation registration Address review: add_conversation_to_memory takes a Conversation; replace upsert with insert-only (no-op on identical, raise on conflicting target). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 11 ++- .../attack/component/conversation_manager.py | 13 ++- .../multi_turn/multi_turn_attack_strategy.py | 6 +- .../executor/attack/multi_turn/red_teaming.py | 7 +- .../attack/multi_turn/tree_of_attacks.py | 5 +- pyrit/memory/memory_interface.py | 98 +++++++++++-------- pyrit/prompt_normalizer/prompt_normalizer.py | 9 +- .../common/discover_target_capabilities.py | 6 +- pyrit/prompt_target/common/prompt_target.py | 4 +- .../_openai_realtime_streaming_session.py | 4 +- pyrit/prompt_target/text_target.py | 4 +- tests/unit/backend/test_attack_service.py | 1 + .../test_interface_prompts.py | 58 ++++++++--- tests/unit/memory/test_azure_sql_memory.py | 4 +- tests/unit/memory/test_sqlite_memory.py | 10 +- .../test_prompt_normalizer.py | 5 + 16 files changed, 163 insertions(+), 82 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index fe3b8f527d..cabff15cfe 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -55,6 +55,7 @@ AttackOutcome, AttackResult, ComponentIdentifier, + Conversation, ConversationStats, ConversationType, MessagePiece, @@ -876,7 +877,7 @@ def _duplicate_conversation_up_to( if all_pieces: self._memory.add_conversation_to_memory( - conversation_id=new_conversation_id, target_identifier=target_identifier + conversation=Conversation(conversation_id=new_conversation_id, target_identifier=target_identifier) ) self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces)) @@ -971,7 +972,9 @@ async def _store_prepended_messages_async( """Store prepended conversation messages in memory.""" if not prepended: return - self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + self._memory.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) for seq, msg in enumerate(prepended): for p in msg.pieces: piece = request_piece_to_pyrit_message_piece( @@ -1031,7 +1034,9 @@ async def _store_message_only_async( ) -> None: """Store message without sending (send=False).""" await self._persist_base64_pieces_async(request) - self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + self._memory.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) for p in request.pieces: piece = request_piece_to_pyrit_message_piece( piece=p, diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 30b9305dac..0cd0c85d05 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -14,7 +14,14 @@ ) from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer -from pyrit.models import ChatMessageRole, ComponentIdentifier, Message, MessagePiece, Score +from pyrit.models import ( + ChatMessageRole, + ComponentIdentifier, + Conversation, + Message, + MessagePiece, + Score, +) from pyrit.prompt_normalizer.prompt_converter_configuration import ( PromptConverterConfiguration, ) @@ -465,7 +472,9 @@ async def add_prepended_conversation_to_memory_async( if not valid_messages: return 0 - self._memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + self._memory.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) # Get roles that should have converters applied apply_to_roles = ( diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 72db4d75a1..6ec93a7aa5 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -17,7 +17,7 @@ AttackStrategyResultT, ) from pyrit.memory import CentralMemory -from pyrit.models import ConversationReference, ConversationType +from pyrit.models import Conversation, ConversationReference, ConversationType from pyrit.prompt_target import CapabilityName if TYPE_CHECKING: @@ -142,7 +142,9 @@ def _rotate_conversation_for_single_turn_target( if system_messages: new_conversation_id, pieces = memory.duplicate_messages(messages=system_messages) memory.add_conversation_to_memory( - conversation_id=new_conversation_id, target_identifier=self._objective_target.get_identifier() + conversation=Conversation( + conversation_id=new_conversation_id, target_identifier=self._objective_target.get_identifier() + ) ) memory.add_message_pieces_to_memory(message_pieces=pieces) context.session.conversation_id = new_conversation_id diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 4c47fced33..4065368730 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -30,6 +30,7 @@ from pyrit.models import ( AttackOutcome, AttackResult, + Conversation, ConversationReference, ConversationType, Message, @@ -273,8 +274,10 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: ) self._memory.add_conversation_to_memory( - conversation_id=context.session.adversarial_chat_conversation_id, - target_identifier=self._adversarial_chat.get_identifier(), + conversation=Conversation( + conversation_id=context.session.adversarial_chat_conversation_id, + target_identifier=self._adversarial_chat.get_identifier(), + ) ) for msg in adversarial_messages: self._memory.add_message_to_memory(request=msg) diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 46909a8384..32d8567aa3 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -43,6 +43,7 @@ AttackOutcome, AttackResult, ComponentIdentifier, + Conversation, ConversationReference, ConversationType, Message, @@ -822,7 +823,9 @@ def duplicate(self) -> "_TreeOfAttacksNode": if system_messages: new_id, pieces = self._memory.duplicate_messages(messages=system_messages) self._memory.add_conversation_to_memory( - conversation_id=new_id, target_identifier=self._objective_target.get_identifier() + conversation=Conversation( + conversation_id=new_id, target_identifier=self._objective_target.get_identifier() + ) ) self._memory.add_message_pieces_to_memory(message_pieces=pieces) duplicate_node.objective_target_conversation_id = new_id diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ff209c7286..9078d36768 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -39,7 +39,6 @@ ) from pyrit.models import ( AttackResult, - ComponentIdentifier, Conversation, ConversationStats, IdentifierFilter, @@ -334,38 +333,42 @@ def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dic @abc.abstractmethod def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> Any: """ - Return a condition for filtering seed prompt entries based on prompt metadata. - - Args: - metadata (dict[str, str | int]): A free-form dictionary for tagging prompts with custom metadata. - This includes information that is useful for the specific target you're probing, such as encoding data. + Return a condition for filtering seed prompt entries based on prompt metadata. + s + Args: + metadata (dict[str, str | int]): A free-form dictionary for tagging prompts with custom metadata. + This includes information that is useful for the specific target you're probing, such as encoding data. Returns: - Any: A SQLAlchemy condition for filtering memory entries based on prompt metadata. + Any: A SQLAlchemy condition for filtering memory entries based on prompt metadata. """ - def add_conversation_to_memory( - self, *, conversation_id: str, target_identifier: ComponentIdentifier | None = None - ) -> None: + def add_conversation_to_memory(self, *, conversation: Conversation) -> None: """ Register a conversation in memory, recording its conversation-scoped metadata. - A conversation is a first-class entity held with a single target. Call this once - when a conversation is created (before, or independently of, adding its messages) - to record the target it is held with. Message writes (``add_message_to_memory`` / - ``add_message_pieces_to_memory``) deliberately do not take a target, so that - conversation ownership is expressed in a single place rather than threaded through - every write. + A conversation is a first-class entity held with a single target. Build a + ``Conversation`` when it is created and call this once (before, or independently + of, adding its messages) to record the target it is held with. Message writes + (``add_message_to_memory`` / ``add_message_pieces_to_memory``) deliberately do + not take a target, so that conversation ownership is expressed in a single place + rather than threaded through every write. - Registration is idempotent: a non-``None`` ``target_identifier`` is recorded, and - a ``None`` value never overwrites a target already recorded for the conversation. + Registration is idempotent only for an identical conversation: re-registering the + same ``conversation_id`` with the same target is a no-op (so repeated per-turn + registration is safe). Re-registering an existing ``conversation_id`` with a + different target is a conflict and raises ``ValueError`` -- a conversation is held + with exactly one target and is never re-targeted. Args: - conversation_id (str): The caller-owned conversation identifier. - target_identifier (ComponentIdentifier | None): The target the conversation is - held with, if known. + conversation (Conversation): The conversation metadata to record, carrying the + ``conversation_id`` and the target it is held with (if known). + + Raises: + ValueError: If ``conversation_id`` is empty, or if a conversation with the same + id already exists with a different target. """ - self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) + self._insert_conversation(conversation=conversation) def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ @@ -433,40 +436,45 @@ def _validate_persistable_conversation_ids(*, message_pieces: Sequence[MessagePi "the caller before a piece is persisted; the memory layer does not generate one." ) - def _upsert_conversation(self, *, conversation_id: str, target_identifier: ComponentIdentifier | None) -> None: + def _insert_conversation(self, *, conversation: Conversation) -> None: """ - Insert or update the ``Conversations`` row for ``conversation_id``. + Insert the ``Conversations`` row for a conversation, never updating an existing one. - A non-``None`` ``target_identifier`` is written; a ``None`` value never - overwrites a target already recorded for the conversation (so re-registration - and copy/duplicate flows cannot clobber it). + A conversation is held with exactly one target, so this is insert-only with + idempotent-on-identical semantics: if no row exists it is inserted; if a row + already exists with the same target it is left untouched; if a row exists with a + different target it is a conflict and raises. Args: - conversation_id (str): The conversation to record. - target_identifier (ComponentIdentifier | None): The target the conversation - is held with, if known. + conversation (Conversation): The conversation metadata to record. Raises: - ValueError: If ``conversation_id`` is empty. - SQLAlchemyError: If the upsert fails. + ValueError: If ``conversation.conversation_id`` is empty, or if a conversation + with the same id already exists with a different target. + SQLAlchemyError: If the insert fails. """ - if not conversation_id: + if not conversation.conversation_id: raise ValueError("Cannot register a conversation without a conversation_id.") - entry = ConversationEntry( - conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) - ) + entry = ConversationEntry(conversation=conversation) with closing(self.get_session()) as session: try: - existing = session.get(ConversationEntry, conversation_id) + existing = session.get(ConversationEntry, conversation.conversation_id) if existing is None: session.add(entry) - elif target_identifier is not None: - existing.target_identifier = entry.target_identifier - existing.pyrit_version = entry.pyrit_version + elif ( + entry.target_identifier is not None + and existing.target_identifier is not None + and existing.target_identifier != entry.target_identifier + ): + raise ValueError( + f"Conversation {conversation.conversation_id} is already registered with a different " + f"target ({existing.target_identifier!r}); a conversation is held with exactly one " + f"target and cannot be re-registered with {entry.target_identifier!r}." + ) session.commit() except SQLAlchemyError as e: session.rollback() - logger.exception(f"Error upserting conversation {conversation_id}: {e}") + logger.exception(f"Error registering conversation {conversation.conversation_id}: {e}") raise @abc.abstractmethod @@ -1234,7 +1242,9 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) if all_pieces: - self.add_conversation_to_memory(conversation_id=new_conversation_id, target_identifier=source_target) + self.add_conversation_to_memory( + conversation=Conversation(conversation_id=new_conversation_id, target_identifier=source_target) + ) self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id @@ -1271,7 +1281,9 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) if all_pieces: - self.add_conversation_to_memory(conversation_id=new_conversation_id, target_identifier=source_target) + self.add_conversation_to_memory( + conversation=Conversation(conversation_id=new_conversation_id, target_identifier=source_target) + ) self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 30ccc2a698..56c7bcfa3c 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -22,6 +22,7 @@ from pyrit.memory import CentralMemory, MemoryInterface, set_message_piece_sha256_async from pyrit.models import ( ComponentIdentifier, + Conversation, Message, MessagePiece, construct_response_from_request, @@ -119,7 +120,9 @@ async def send_prompt_async( request = copy.deepcopy(message) conversation_id = conversation_id if conversation_id else str(uuid4()) target_identifier = target.get_identifier() - self.memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + self.memory.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) for piece in request.message_pieces: piece.conversation_id = conversation_id @@ -438,7 +441,9 @@ async def add_prepended_conversation_to_memory_async( # Create a deep copy of the prepended conversation to avoid modifying the original prepended_conversation = copy.deepcopy(prepended_conversation) - self.memory.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_identifier) + self.memory.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) for request in prepended_conversation: if should_convert and converter_configurations: diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 0a7c7cb93f..9c227ba02d 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -45,7 +45,7 @@ from pathlib import Path from pyrit.common.path import DATASETS_PATH -from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.models import Conversation, Message, MessagePiece, PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import ( CapabilityName, @@ -323,7 +323,7 @@ async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, ret ) try: target._memory.add_conversation_to_memory( - conversation_id=conversation_id, target_identifier=target.get_identifier() + conversation=Conversation(conversation_id=conversation_id, target_identifier=target.get_identifier()) ) target._memory.add_message_to_memory(request=Message(message_pieces=[system_piece])) except Exception as exc: @@ -410,7 +410,7 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie # Seed memory so the second send sees real prior history. try: target._memory.add_conversation_to_memory( - conversation_id=conversation_id, target_identifier=target.get_identifier() + conversation=Conversation(conversation_id=conversation_id, target_identifier=target.get_identifier()) ) target._memory.add_message_to_memory(request=Message(message_pieces=[first])) assistant_reply = MessagePiece( diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 5d0cc93df7..9c0cf464e3 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -7,7 +7,7 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import ComponentIdentifier, Identifiable, Message, MessagePiece +from pyrit.models import ComponentIdentifier, Conversation, Identifiable, Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -328,7 +328,7 @@ def set_system_prompt( raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") self._memory.add_conversation_to_memory( - conversation_id=conversation_id, target_identifier=self.get_identifier() + conversation=Conversation(conversation_id=conversation_id, target_identifier=self.get_identifier()) ) self._memory.add_message_to_memory( request=MessagePiece( diff --git a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py index 9d5d5b0a35..1581d445a6 100644 --- a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py +++ b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from pyrit.models import Message, MessagePiece +from pyrit.models import Conversation, Message, MessagePiece from pyrit.prompt_target.common.realtime_audio import ( STREAMING_INTERRUPTED_KEY, RealtimeTargetResult, @@ -403,7 +403,7 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: target_identifier = target.get_identifier() target._memory.add_conversation_to_memory( - conversation_id=self._conversation_id, target_identifier=target_identifier + conversation=Conversation(conversation_id=self._conversation_id, target_identifier=target_identifier) ) user_piece = MessagePiece( role="user", diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index a4b69fd4a7..2601b0c6d0 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -8,7 +8,7 @@ from typing import IO from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import Message, MessagePiece +from pyrit.models import Conversation, Message, MessagePiece from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -102,7 +102,7 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: # This is post validation, so the message_pieces should be okay and normalized for conversation_id in {piece.conversation_id for piece in message_pieces if piece.conversation_id}: self._memory.add_conversation_to_memory( - conversation_id=conversation_id, target_identifier=self.get_identifier() + conversation=Conversation(conversation_id=conversation_id, target_identifier=self.get_identifier()) ) self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) return message_pieces diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 0de9d88d38..be251ade7f 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -38,6 +38,7 @@ def mock_memory(): memory.get_conversation.return_value = [] memory.get_message_pieces.return_value = [] memory.get_conversation_stats.return_value = {} + memory.get_conversation_metadata.return_value = None return memory diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 2c05701823..62c485c99d 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -16,6 +16,7 @@ from pyrit.models import ( AttackResult, ComponentIdentifier, + Conversation, IdentifierFilter, IdentifierType, Message, @@ -596,7 +597,9 @@ def test_add_conversation_to_memory_records_target_for_plain_message_writes(sqli params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, ) conversation_id = "conv-registered" - sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_id) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_id) + ) sqlite_instance.add_message_pieces_to_memory( message_pieces=[MessagePiece(role="user", original_value="hi", conversation_id=conversation_id)] ) @@ -632,9 +635,30 @@ def test_message_writes_without_registration_create_no_conversation_row(sqlite_i assert len(sqlite_instance.get_message_pieces(conversation_id=conversation_id)) == 1 -def test_add_conversation_to_memory_updates_existing_target_on_reregister(sqlite_instance: MemoryInterface): - # Re-registering a conversation with a new non-null target overwrites the previously - # recorded one. (A None re-registration never clobbers -- covered separately.) +def test_add_conversation_to_memory_same_target_reregister_is_noop(sqlite_instance: MemoryInterface): + # A conversation is held with exactly one target. Re-registering the same + # conversation with the same target is idempotent (no error, no change) so that + # per-turn registration during a multi-turn conversation is safe. + conversation_id = "conv-reregister-same" + target = ComponentIdentifier( + class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"endpoint": "a"} + ) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target) + ) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target) + ) + + metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + assert metadata is not None + assert metadata.target_identifier.hash == target.hash + + +def test_add_conversation_to_memory_different_target_reregister_raises(sqlite_instance: MemoryInterface): + # A conversation is held with exactly one target, so re-registering an existing + # conversation_id with a different target is a conflict and must raise rather than + # silently re-targeting the conversation. conversation_id = "conv-retarget" target_a = ComponentIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"endpoint": "a"} @@ -642,16 +666,22 @@ def test_add_conversation_to_memory_updates_existing_target_on_reregister(sqlite target_b = ComponentIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"endpoint": "b"} ) - sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_a) - sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=target_b) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_a) + ) + with pytest.raises(ValueError, match="already registered with a different target"): + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_b) + ) + # The originally recorded target is left untouched. metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) assert metadata is not None - assert metadata.target_identifier.hash == target_b.hash + assert metadata.target_identifier.hash == target_a.hash -def test_upsert_conversation_rolls_back_and_reraises_on_db_error(sqlite_instance: MemoryInterface): - # A DB failure during the upsert rolls back the session and propagates the error +def test_insert_conversation_rolls_back_and_reraises_on_db_error(sqlite_instance: MemoryInterface): + # A DB failure during registration rolls back the session and propagates the error # rather than leaving a half-written Conversations row. from sqlalchemy.exc import SQLAlchemyError @@ -660,7 +690,7 @@ def test_upsert_conversation_rolls_back_and_reraises_on_db_error(sqlite_instance with patch.object(sqlite_instance, "get_session", return_value=session): with pytest.raises(SQLAlchemyError, match="boom"): - sqlite_instance._upsert_conversation(conversation_id="conv-fail", target_identifier=None) + sqlite_instance._insert_conversation(conversation=Conversation(conversation_id="conv-fail")) session.rollback.assert_called_once() session.commit.assert_not_called() @@ -1452,7 +1482,9 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, ) - sqlite_instance.add_conversation_to_memory(conversation_id="conv-openai", target_identifier=target_id_1) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id="conv-openai", target_identifier=target_id_1) + ) sqlite_instance.add_message_pieces_to_memory( message_pieces=[ MessagePiece( @@ -1462,7 +1494,9 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI ), ], ) - sqlite_instance.add_conversation_to_memory(conversation_id="conv-azure", target_identifier=target_id_2) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id="conv-azure", target_identifier=target_id_2) + ) sqlite_instance.add_message_pieces_to_memory( message_pieces=[ MessagePiece( diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 64c712861c..7289285218 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -11,7 +11,7 @@ from sqlalchemy import inspect, text from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry -from pyrit.models import MessagePiece +from pyrit.models import Conversation, MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget from unit.mocks import get_azure_sql_memory, get_sample_conversation_entries @@ -207,7 +207,7 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): ) memory_interface.add_conversation_to_memory( - conversation_id=specific_conversation_id, target_identifier=target.get_identifier() + conversation=Conversation(conversation_id=specific_conversation_id, target_identifier=target.get_identifier()) ) memory_interface.add_message_pieces_to_memory(message_pieces=[piece]) diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 2be2fb09d4..1e88fe9927 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -19,7 +19,7 @@ from pyrit.memory.alembic.versions.ab8f2c1a9d07_pre_alembic_release_schema import INITIAL_METADATA from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry from pyrit.memory.migration import run_schema_migrations -from pyrit.models import MessagePiece +from pyrit.models import Conversation, MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget from unit.mocks import get_sample_conversation_entries @@ -529,7 +529,7 @@ def test_get_memories_with_json_properties(sqlite_instance): ) sqlite_instance.add_conversation_to_memory( - conversation_id=specific_conversation_id, target_identifier=target.get_identifier() + conversation=Conversation(conversation_id=specific_conversation_id, target_identifier=target.get_identifier()) ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) @@ -572,7 +572,7 @@ def test_register_conversation_none_target_does_not_clobber(sqlite_instance): original_value="hello", ) sqlite_instance.add_conversation_to_memory( - conversation_id=conversation_id, target_identifier=target.get_identifier() + conversation=Conversation(conversation_id=conversation_id, target_identifier=target.get_identifier()) ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[request_piece]) @@ -582,7 +582,9 @@ def test_register_conversation_none_target_does_not_clobber(sqlite_instance): sequence=2, original_value="world", ) - sqlite_instance.add_conversation_to_memory(conversation_id=conversation_id, target_identifier=None) + sqlite_instance.add_conversation_to_memory( + conversation=Conversation(conversation_id=conversation_id, target_identifier=None) + ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece]) metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 10501956cc..392c8cf451 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -193,6 +193,7 @@ async def test_send_prompt_async_empty_response_exception_handled(mock_memory_in async def test_send_prompt_async_request_response_added_to_memory(mock_memory_instance, seed_group): # Use MagicMock with send_prompt_async as AsyncMock to avoid coroutine warnings prompt_target = MagicMock() + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") response = MessagePiece(role="assistant", original_value="test_response").to_message() @@ -281,6 +282,7 @@ async def test_send_prompt_async_mixed_sequence_types(mock_memory_instance): async def test_send_prompt_async_adds_memory_twice(mock_memory_instance, seed_group, response: Message): prompt_target = MagicMock() + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") prompt_target.send_prompt_async = AsyncMock(return_value=[response]) normalizer = PromptNormalizer() @@ -292,6 +294,7 @@ async def test_send_prompt_async_adds_memory_twice(mock_memory_instance, seed_gr async def test_send_prompt_async_no_converters_response(mock_memory_instance, seed_group, response: Message): prompt_target = MagicMock() + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") prompt_target.send_prompt_async = AsyncMock(return_value=[response]) normalizer = PromptNormalizer() @@ -304,6 +307,7 @@ async def test_send_prompt_async_no_converters_response(mock_memory_instance, se async def test_send_prompt_async_converters_response(mock_memory_instance, seed_group, response: Message): prompt_target = MagicMock() + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") prompt_target.send_prompt_async = AsyncMock(return_value=[response]) response_converter = PromptConverterConfiguration(converters=[Base64Converter()], indexes_to_apply=[0]) @@ -322,6 +326,7 @@ async def test_send_prompt_async_converters_response(mock_memory_instance, seed_ async def test_send_prompt_async_image_converter(mock_memory_instance): prompt_target = MagicMock(PromptTarget) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") prompt_target.send_prompt_async = AsyncMock( return_value=[MessagePiece(role="assistant", original_value="response").to_message()] ) From db46ea4a2e5354b802b86def56b087fef9744065 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 10 Jun 2026 12:42:36 -0700 Subject: [PATCH 09/12] Deprecate MemoryInterface.get_conversation; add get_conversation_messages Free up the get_conversation name so it can eventually return the Conversation entity. Phased, non-breaking approach: - get_conversation_messages: new canonical messages-getter (moved body here). - get_conversation: deprecated shim that warns and delegates to get_conversation_messages; removed in 0.17.0. - _get_conversation: renamed from get_conversation_metadata; returns the Conversation entity. Temporary leading underscore (promoted to public get_conversation once the deprecated shim is removed). Migrate all internal callers, tests, and docs to the new names. Add a deprecation test asserting get_conversation warns and matches get_conversation_messages. Also fix latent str | None blast-radius from the MessagePiece.conversation_id change: guard RealtimeTarget and WebSocketCopilotTarget send paths to fail loud when a message has no conversation_id (with focused tests). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../attack/2_red_teaming_attack.ipynb | 2 +- .../executor/attack/2_red_teaming_attack.py | 2 +- .../executor/attack/barge_in_attack.ipynb | 2 +- doc/code/executor/attack/barge_in_attack.py | 2 +- .../memory/2_basic_memory_programming.ipynb | 2 +- doc/code/memory/2_basic_memory_programming.py | 2 +- doc/code/memory/3_memory_data_types.md | 2 +- doc/code/memory/6_azure_sql_memory.ipynb | 2 +- doc/code/memory/6_azure_sql_memory.py | 2 +- doc/code/output/0_output.ipynb | 2 +- doc/code/output/0_output.py | 2 +- pyrit/backend/services/attack_service.py | 8 +-- .../attack/component/conversation_manager.py | 2 +- .../multi_turn/multi_turn_attack_strategy.py | 2 +- .../multi_turn/simulated_conversation.py | 2 +- .../attack/multi_turn/tree_of_attacks.py | 6 +- pyrit/executor/benchmark/fairness_bias.py | 2 +- pyrit/memory/memory_interface.py | 41 ++++++++++--- pyrit/memory/memory_models.py | 2 +- pyrit/output/attack_result/markdown.py | 2 +- pyrit/output/attack_result/pretty.py | 2 +- pyrit/prompt_target/common/prompt_target.py | 4 +- .../openai/openai_realtime_target.py | 5 +- .../prompt_target/websocket_copilot_target.py | 5 +- pyrit/score/conversation_scorer.py | 4 +- pyrit/score/true_false/gandalf_scorer.py | 2 +- .../targets/test_target_filters.py | 8 +-- .../test_sqlite_memory_contract.py | 8 +-- tests/unit/backend/test_attack_service.py | 58 +++++++++---------- .../component/test_simulated_conversation.py | 28 ++++----- .../test_supports_multi_turn_attacks.py | 46 +++++++-------- .../attack/multi_turn/test_tree_of_attacks.py | 4 +- .../test_attack_parameter_consistency.py | 10 ++-- .../attack/test_error_skip_scoring.py | 2 +- .../executor/benchmark/test_fairness_bias.py | 12 ++-- .../test_interface_prompts.py | 45 +++++++++++--- tests/unit/memory/test_azure_sql_memory.py | 4 +- tests/unit/memory/test_sqlite_memory.py | 6 +- .../target/test_azure_ml_chat_target.py | 8 +-- .../prompt_target/target/test_image_target.py | 6 +- .../test_normalize_async_integration.py | 28 ++++----- .../target/test_openai_chat_target.py | 14 ++--- .../target/test_openai_response_target.py | 16 ++--- ...penai_response_target_function_chaining.py | 2 +- .../target/test_prompt_target.py | 16 ++--- .../target/test_realtime_target.py | 14 +++++ .../prompt_target/target/test_tts_target.py | 4 +- .../target/test_websocket_copilot_target.py | 18 ++++-- .../prompt_target/test_prompt_chat_target.py | 2 +- 49 files changed, 279 insertions(+), 191 deletions(-) diff --git a/doc/code/executor/attack/2_red_teaming_attack.ipynb b/doc/code/executor/attack/2_red_teaming_attack.ipynb index 6a84529a08..4c1450b854 100644 --- a/doc/code/executor/attack/2_red_teaming_attack.ipynb +++ b/doc/code/executor/attack/2_red_teaming_attack.ipynb @@ -523,7 +523,7 @@ "\n", "num_turns_to_remove = 2\n", "memory = CentralMemory.get_memory_instance()\n", - "conversation_history = memory.get_conversation(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]\n", + "conversation_history = memory.get_conversation_messages(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]\n", "prepended_conversation = conversation_history\n", "\"\"\"\n", "\n", diff --git a/doc/code/executor/attack/2_red_teaming_attack.py b/doc/code/executor/attack/2_red_teaming_attack.py index 24bcfc6855..5f3d5ae200 100644 --- a/doc/code/executor/attack/2_red_teaming_attack.py +++ b/doc/code/executor/attack/2_red_teaming_attack.py @@ -157,7 +157,7 @@ num_turns_to_remove = 2 memory = CentralMemory.get_memory_instance() -conversation_history = memory.get_conversation(conversation_id=result.conversation_id)[:-num_turns_to_remove*2] +conversation_history = memory.get_conversation_messages(conversation_id=result.conversation_id)[:-num_turns_to_remove*2] prepended_conversation = conversation_history """ diff --git a/doc/code/executor/attack/barge_in_attack.ipynb b/doc/code/executor/attack/barge_in_attack.ipynb index 1e68141cab..dacd029c8b 100644 --- a/doc/code/executor/attack/barge_in_attack.ipynb +++ b/doc/code/executor/attack/barge_in_attack.ipynb @@ -228,7 +228,7 @@ "\n", "# Inspect memory to verify the barge-in landed in metadata.\n", "memory = CentralMemory.get_memory_instance()\n", - "turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id)\n", + "turns = memory.get_conversation_messages(conversation_id=barge_in_result.conversation_id)\n", "print(f\"\\nPersisted pieces ({len(turns)} messages):\")\n", "for message in turns:\n", " for piece in message.message_pieces:\n", diff --git a/doc/code/executor/attack/barge_in_attack.py b/doc/code/executor/attack/barge_in_attack.py index c316e1b184..b4ab4a1212 100644 --- a/doc/code/executor/attack/barge_in_attack.py +++ b/doc/code/executor/attack/barge_in_attack.py @@ -178,7 +178,7 @@ async def barge_in_source(): # Inspect memory to verify the barge-in landed in metadata. memory = CentralMemory.get_memory_instance() -turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id) +turns = memory.get_conversation_messages(conversation_id=barge_in_result.conversation_id) print(f"\nPersisted pieces ({len(turns)} messages):") for message in turns: for piece in message.message_pieces: diff --git a/doc/code/memory/2_basic_memory_programming.ipynb b/doc/code/memory/2_basic_memory_programming.ipynb index e58e4d18fa..e9b773fd04 100644 --- a/doc/code/memory/2_basic_memory_programming.ipynb +++ b/doc/code/memory/2_basic_memory_programming.ipynb @@ -55,7 +55,7 @@ "memory.add_message_to_memory(request=message_list[1].to_message())\n", "memory.add_message_to_memory(request=message_list[2].to_message())\n", "\n", - "entries = memory.get_conversation(conversation_id=conversation_id)\n", + "entries = memory.get_conversation_messages(conversation_id=conversation_id)\n", "\n", "for entry in entries:\n", " print(entry)" diff --git a/doc/code/memory/2_basic_memory_programming.py b/doc/code/memory/2_basic_memory_programming.py index 4aed9bbf88..3926bd5a15 100644 --- a/doc/code/memory/2_basic_memory_programming.py +++ b/doc/code/memory/2_basic_memory_programming.py @@ -42,7 +42,7 @@ memory.add_message_to_memory(request=message_list[1].to_message()) memory.add_message_to_memory(request=message_list[2].to_message()) -entries = memory.get_conversation(conversation_id=conversation_id) +entries = memory.get_conversation_messages(conversation_id=conversation_id) for entry in entries: print(entry) diff --git a/doc/code/memory/3_memory_data_types.md b/doc/code/memory/3_memory_data_types.md index 0fd1b0988d..bd78c52f24 100644 --- a/doc/code/memory/3_memory_data_types.md +++ b/doc/code/memory/3_memory_data_types.md @@ -52,7 +52,7 @@ This rich context allows PyRIT to track the full lifecycle of each interaction, A conversation is a list of `Messages` that share the same `conversation_id`. The sequence of the `MessagePieces` and their corresponding `Messages` dictates the order of the conversation. -A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory.get_conversation_metadata(conversation_id=...)` to retrieve it. +A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory._get_conversation(conversation_id=...)` to retrieve it. Here is a sample conversation made up of three `Messages` which all share the same conversation ID. The first `Message` is the `system` message, followed by a multi-modal `user` prompt with a text `MessagePiece` and an image `MessagePiece`, and finally the `assistant` response in the form of a text `MessagePiece`. diff --git a/doc/code/memory/6_azure_sql_memory.ipynb b/doc/code/memory/6_azure_sql_memory.ipynb index 138673adaf..3d674de8a8 100644 --- a/doc/code/memory/6_azure_sql_memory.ipynb +++ b/doc/code/memory/6_azure_sql_memory.ipynb @@ -177,7 +177,7 @@ "memory.add_message_to_memory(request=Message([message_list[1]]))\n", "memory.add_message_to_memory(request=Message([message_list[2]]))\n", "\n", - "entries = memory.get_conversation(conversation_id=conversation_id)\n", + "entries = memory.get_conversation_messages(conversation_id=conversation_id)\n", "\n", "for entry in entries:\n", " print(entry)" diff --git a/doc/code/memory/6_azure_sql_memory.py b/doc/code/memory/6_azure_sql_memory.py index 75ede398cc..693ea2e8c1 100644 --- a/doc/code/memory/6_azure_sql_memory.py +++ b/doc/code/memory/6_azure_sql_memory.py @@ -77,7 +77,7 @@ memory.add_message_to_memory(request=Message([message_list[1]])) memory.add_message_to_memory(request=Message([message_list[2]])) -entries = memory.get_conversation(conversation_id=conversation_id) +entries = memory.get_conversation_messages(conversation_id=conversation_id) for entry in entries: print(entry) diff --git a/doc/code/output/0_output.ipynb b/doc/code/output/0_output.ipynb index 7651a29b43..3ec202574c 100644 --- a/doc/code/output/0_output.ipynb +++ b/doc/code/output/0_output.ipynb @@ -505,7 +505,7 @@ "from pyrit.output import output_conversation_async\n", "\n", "# get the conversation from memory using the conversation id from the attack result\n", - "conversation = memory.get_conversation(conversation_id=attack_result.conversation_id)\n", + "conversation = memory.get_conversation_messages(conversation_id=attack_result.conversation_id)\n", "\n", "# print the conversation using the print conversation helper\n", "await output_conversation_async(messages=conversation) # type: ignore" diff --git a/doc/code/output/0_output.py b/doc/code/output/0_output.py index a75797e075..10a03840b3 100644 --- a/doc/code/output/0_output.py +++ b/doc/code/output/0_output.py @@ -138,7 +138,7 @@ from pyrit.output import output_conversation_async # get the conversation from memory using the conversation id from the attack result -conversation = memory.get_conversation(conversation_id=attack_result.conversation_id) +conversation = memory.get_conversation_messages(conversation_id=attack_result.conversation_id) # print the conversation using the print conversation helper await output_conversation_async(messages=conversation) # type: ignore diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index cabff15cfe..cd23e1d484 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -267,7 +267,7 @@ async def get_conversation_messages_async( raise ValueError(f"Conversation '{conversation_id}' is not part of attack '{attack_result_id}'") # Get messages for this conversation - pyrit_messages = self._memory.get_conversation(conversation_id=conversation_id) + pyrit_messages = self._memory.get_conversation_messages(conversation_id=conversation_id) backend_messages = await pyrit_messages_to_dto_async(list(pyrit_messages)) return ConversationMessagesResponse( @@ -478,7 +478,7 @@ async def create_related_conversation_async( # --- Branch via duplication (preferred for tracking) --------------- if request.source_conversation_id is not None and request.cutoff_index is not None: - source_metadata = self._memory.get_conversation_metadata(conversation_id=request.source_conversation_id) + source_metadata = self._memory._get_conversation(conversation_id=request.source_conversation_id) new_conversation_id = self._duplicate_conversation_up_to( source_conversation_id=request.source_conversation_id, cutoff_index=request.cutoff_index, @@ -627,7 +627,7 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR labels=attack_labels, # deprecated ) else: - existing_metadata = self._memory.get_conversation_metadata(conversation_id=msg_conversation_id) + existing_metadata = self._memory._get_conversation(conversation_id=msg_conversation_id) await self._store_message_only_async( conversation_id=msg_conversation_id, request=request, @@ -860,7 +860,7 @@ def _duplicate_conversation_up_to( Returns: The new conversation ID containing the duplicated messages. """ - messages = self._memory.get_conversation(conversation_id=source_conversation_id) + messages = self._memory.get_conversation_messages(conversation_id=source_conversation_id) messages_to_copy = [m for m in messages if m.sequence <= cutoff_index] new_conversation_id, all_pieces = self._memory.duplicate_messages(messages=messages_to_copy) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 8e65ba6d0f..4ca1e6ac21 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -214,7 +214,7 @@ def get_conversation(self, conversation_id: str) -> list[Message]: A list of messages in the conversation, ordered by creation time. Returns empty list if no messages exist. """ - conversation = self._memory.get_conversation(conversation_id=conversation_id) + conversation = self._memory.get_conversation_messages(conversation_id=conversation_id) return list(conversation) def get_last_message(self, *, conversation_id: str, role: ChatMessageRole | None = None) -> MessagePiece | None: diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 6ec93a7aa5..db7a44755b 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -136,7 +136,7 @@ def _rotate_conversation_for_single_turn_target( # Duplicate system messages (e.g., system prompt from prepended conversation) # into the new conversation so the target retains its configuration. memory = CentralMemory.get_memory_instance() - messages = memory.get_conversation(conversation_id=old_conversation_id) + messages = memory.get_conversation_messages(conversation_id=old_conversation_id) system_messages = [m for m in messages if m.api_role == "system"] if system_messages: diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index 55db14cb6c..def7f590ee 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -143,7 +143,7 @@ async def generate_simulated_conversation_async( # Extract the conversation from memory and filter for prepended_conversation use memory = CentralMemory.get_memory_instance() - raw_messages = list(memory.get_conversation(conversation_id=result.conversation_id)) + raw_messages = list(memory.get_conversation_messages(conversation_id=result.conversation_id)) # Filter out system messages - keep the actual conversation # System prompts are set separately on each target during attack execution diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 8339f79d29..703278d91d 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -819,7 +819,7 @@ def duplicate(self) -> "_TreeOfAttacksNode": conversation_id=self.objective_target_conversation_id ) else: - messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id) + messages = self._memory.get_conversation_messages(conversation_id=self.objective_target_conversation_id) system_messages = [m for m in messages if m.api_role == "system"] if system_messages: new_id, pieces = self._memory.duplicate_messages(messages=system_messages) @@ -987,7 +987,7 @@ def _is_first_turn(self) -> bool: bool: True if no messages exist in the objective target conversation (first turn), False if the conversation already contains messages (subsequent turns). """ - target_messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id) + target_messages = self._memory.get_conversation_messages(conversation_id=self.objective_target_conversation_id) return not target_messages async def _generate_first_turn_prompt_async(self, objective: str) -> str: @@ -1059,7 +1059,7 @@ async def _generate_subsequent_turn_prompt_async(self, objective: str) -> str: one prior exchange. """ # Get conversation history - target_messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id) + target_messages = self._memory.get_conversation_messages(conversation_id=self.objective_target_conversation_id) # Extract the last assistant response assistant_responses = [r for r in target_messages if r.get_piece().api_role == "assistant"] diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 4e3bfaa505..ef4aba1180 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -242,7 +242,7 @@ def _format_experiment_results( Returns: Dict: dictionary with components from experiment parsed and formatted """ - conversation_pieces = self.memory.get_conversation(conversation_id=attack_result.conversation_id) + conversation_pieces = self.memory.get_conversation_messages(conversation_id=attack_result.conversation_id) response = conversation_pieces[1].get_value() if len(conversation_pieces) >= 2 else "" subject_name = self._extract_name(response) return { diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 7d0f422f62..ca4529144b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -953,7 +953,7 @@ def get_prompt_scores( ) return [entry.get_score() for entry in score_entries] - def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: + def get_conversation_messages(self, *, conversation_id: str) -> MutableSequence[Message]: """ Retrieve a list of Message objects that have the specified conversation ID. @@ -966,7 +966,29 @@ def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: message_pieces = self.get_message_pieces(conversation_id=conversation_id) return group_conversation_message_pieces_by_sequence(message_pieces=message_pieces) - def get_conversation_metadata(self, *, conversation_id: str) -> Conversation | None: + def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: + """ + Retrieve the messages for a conversation (deprecated alias). + + .. deprecated:: + Use ``get_conversation_messages`` instead. The ``get_conversation`` name is + being freed so it can return the conversation entity (currently exposed as + ``_get_conversation``) in a future release. + + Args: + conversation_id (str): The conversation ID to match. + + Returns: + MutableSequence[Message]: A list of chat memory entries with the specified conversation ID. + """ + print_deprecation_message( + old_item="MemoryInterface.get_conversation", + new_item="MemoryInterface.get_conversation_messages", + removed_in="0.17.0", + ) + return self.get_conversation_messages(conversation_id=conversation_id) + + def _get_conversation(self, *, conversation_id: str) -> Conversation | None: """ Return the conversation-scoped metadata stored for ``conversation_id``. @@ -977,6 +999,11 @@ def get_conversation_metadata(self, *, conversation_id: str) -> Conversation | N Conversation | None: The conversation metadata (including the target identifier), or ``None`` if no row exists for the conversation. """ + # NOTE: The leading underscore is temporary. This method returns the conversation + # entity (metadata) and will be promoted to the public ``get_conversation`` once the + # deprecated, messages-returning ``get_conversation`` above is removed in 0.17.0. The + # underscore exists only to avoid colliding with that still-public method during the + # deprecation window. entries = self._query_entries( ConversationEntry, conditions=ConversationEntry.conversation_id == str(conversation_id), @@ -1003,7 +1030,7 @@ def get_request_from_response(self, *, response: Message) -> Message: if response.sequence < 1: raise ValueError("The provided request does not have a preceding request (sequence < 1).") - conversation = self.get_conversation(conversation_id=response.conversation_id) + conversation = self.get_conversation_messages(conversation_id=response.conversation_id) return conversation[response.sequence - 1] def _resolve_attack_id_to_conversation_condition(self, *, attack_id: str | uuid.UUID) -> Any: @@ -1222,8 +1249,8 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: Returns: The uuid for the new conversation. """ - messages = self.get_conversation(conversation_id=conversation_id) - source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + messages = self.get_conversation_messages(conversation_id=conversation_id) + source_metadata = self._get_conversation(conversation_id=conversation_id) source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) if all_pieces: @@ -1246,7 +1273,7 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> Returns: The uuid for the new conversation. """ - messages = self.get_conversation(conversation_id=conversation_id) + messages = self.get_conversation_messages(conversation_id=conversation_id) # remove the final turn from the conversation if len(messages) == 0: @@ -1262,7 +1289,7 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> message for message in messages if message.sequence <= last_message.sequence - length_of_sequence_to_remove ] - source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + source_metadata = self._get_conversation(conversation_id=conversation_id) source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) if all_pieces: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 6ffdd8b8ba..ab467c31a7 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -377,7 +377,7 @@ class ConversationEntry(Base): Holds identifiers that belong to the conversation as a whole -- currently the target identifier -- so they are not duplicated onto every ``PromptMemoryEntry`` row. The target is captured once when the conversation's pieces are written and - read back via ``MemoryInterface.get_conversation_metadata`` (it is not stamped + read back via ``MemoryInterface._get_conversation`` (it is not stamped onto individual pieces). """ diff --git a/pyrit/output/attack_result/markdown.py b/pyrit/output/attack_result/markdown.py index a61becd21a..13b38c323d 100644 --- a/pyrit/output/attack_result/markdown.py +++ b/pyrit/output/attack_result/markdown.py @@ -413,7 +413,7 @@ async def _get_conversation_async(self, conversation_id: str) -> list[Message]: Returns: list[Message]: The conversation messages. """ - return list(self._memory.get_conversation(conversation_id=conversation_id)) + return list(self._memory.get_conversation_messages(conversation_id=conversation_id)) async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ diff --git a/pyrit/output/attack_result/pretty.py b/pyrit/output/attack_result/pretty.py index 8c00e43ef1..db8ff5a65a 100644 --- a/pyrit/output/attack_result/pretty.py +++ b/pyrit/output/attack_result/pretty.py @@ -538,7 +538,7 @@ async def _get_conversation_async(self, conversation_id: str) -> list[Message]: Returns: list[Message]: The conversation messages. """ - return list(self._memory.get_conversation(conversation_id=conversation_id)) + return list(self._memory.get_conversation_messages(conversation_id=conversation_id)) async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index e2d35a9cb6..4bef9e3a26 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -216,7 +216,7 @@ async def _get_normalized_conversation_async(self, *, message: Message) -> list[ """ conversation_id = message.message_pieces[0].conversation_id conversation = ( - list(self._memory.get_conversation(conversation_id=conversation_id)) if conversation_id else [] + list(self._memory.get_conversation_messages(conversation_id=conversation_id)) if conversation_id else [] ) conversation.append(message) normalized = await self.configuration.normalize_async(messages=conversation) @@ -337,7 +337,7 @@ def set_system_prompt( "It must support both multi-turn conversations and editable history." ) - messages = self._memory.get_conversation(conversation_id=conversation_id) + messages = self._memory.get_conversation_messages(conversation_id=conversation_id) if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 0f61c28c4c..f14c83fc15 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -362,7 +362,7 @@ async def send_config_async(self, *, conversation_id: str, conversation: list[Me resolved_conversation = ( conversation if conversation is not None - else list(self._memory.get_conversation(conversation_id=conversation_id)) + else list(self._memory.get_conversation_messages(conversation_id=conversation_id)) ) system_prompt = self._get_system_prompt_from_conversation(conversation=resolved_conversation) config_variables = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) @@ -426,6 +426,9 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me conversation_id = message.message_pieces[0].conversation_id request = message.message_pieces[0] + if not conversation_id: + raise ValueError("RealtimeTarget requires a conversation_id on the message being sent.") + if conversation_id not in self._existing_conversation: connection = await self._connect_async(conversation_id=conversation_id) self._existing_conversation[conversation_id] = connection diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 98b380e35b..d9450ff590 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -616,7 +616,7 @@ def _is_start_of_session(self, *, conversation_id: str) -> bool: Returns: bool: True if no prior messages exist in this conversation, False otherwise. """ - conversation_history = self._memory.get_conversation(conversation_id=conversation_id) + conversation_history = self._memory.get_conversation_messages(conversation_id=conversation_id) return len(conversation_history) == 0 def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tuple[str, str]: @@ -658,6 +658,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me list[Message]: A list containing the response from Copilot. Raises: + ValueError: If the message being sent has no conversation_id. EmptyResponseException: If the response from Copilot is empty. InvalidStatus: If the WebSocket handshake fails with an HTTP status error. RuntimeError: If any other error occurs during WebSocket communication. @@ -665,6 +666,8 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me message = normalized_conversation[-1] pyrit_conversation_id = message.message_pieces[0].conversation_id + if not pyrit_conversation_id: + raise ValueError("WebSocketCopilotTarget requires a conversation_id on the message being sent.") is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id) session_id, copilot_conversation_id = self._generate_consistent_copilot_ids( diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index 8adc403be4..e22b563096 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -62,7 +62,9 @@ async def _score_async(self, message: Message, *, objective: str | None = None) conversation_id = message.message_pieces[0].conversation_id # Retrieve the full conversation from memory using the conversation_id - conversation = self._memory.get_conversation(conversation_id=conversation_id) if conversation_id else [] + conversation = ( + self._memory.get_conversation_messages(conversation_id=conversation_id) if conversation_id else [] + ) if not conversation: raise ValueError(f"Conversation with ID {conversation_id} not found in memory.") diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index e28ad75bbc..1064aaa28e 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -108,7 +108,7 @@ async def _check_for_password_in_conversation_async(self, conversation_id: str) conversation_id=scoring_conversation_id, ) - conversation = self._memory.get_conversation(conversation_id=conversation_id) + conversation = self._memory.get_conversation_messages(conversation_id=conversation_id) if not conversation: raise ValueError(f"Conversation with ID {conversation_id} not found in memory.") diff --git a/tests/integration/targets/test_target_filters.py b/tests/integration/targets/test_target_filters.py index 3eecd5f1fc..25d77be313 100644 --- a/tests/integration/targets/test_target_filters.py +++ b/tests/integration/targets/test_target_filters.py @@ -42,7 +42,7 @@ async def test_azure_content_filters(sqlite_instance, endpoint, api_key, model_n attack = PromptSendingAttack(objective_target=target) result = await attack.execute_async(objective=prompt) assert result is not None - conversation = sqlite_instance.get_conversation(conversation_id=result.conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=result.conversation_id) assert len(conversation) == 2 response = conversation[-1] assert len(response.message_pieces) == 1 @@ -82,7 +82,7 @@ async def test_azure_content_filters_response_api(sqlite_instance, endpoint, api attack = PromptSendingAttack(objective_target=target) result = await attack.execute_async(objective=prompt) assert result is not None - conversation = sqlite_instance.get_conversation(conversation_id=result.conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=result.conversation_id) assert len(conversation) == 2 response = conversation[-1] assert len(response.message_pieces) == 1 @@ -106,7 +106,7 @@ async def test_image_input_filters(sqlite_instance, endpoint, api_key, model_nam attack = PromptSendingAttack(objective_target=target) result = await attack.execute_async(objective=prompt) assert result is not None - conversation = sqlite_instance.get_conversation(conversation_id=result.conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=result.conversation_id) assert len(conversation) == 2 response = conversation[-1] assert len(response.message_pieces) == 1 @@ -132,7 +132,7 @@ async def test_video_input_filters(sqlite_instance, endpoint, api_key, model_nam attack = PromptSendingAttack(objective_target=target) result = await attack.execute_async(objective=prompt) assert result is not None - conversation = sqlite_instance.get_conversation(conversation_id=result.conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=result.conversation_id) assert len(conversation) == 2 response = conversation[-1] assert len(response.message_pieces) == 1 diff --git a/tests/partner_integration/azure_ai_evaluation/test_sqlite_memory_contract.py b/tests/partner_integration/azure_ai_evaluation/test_sqlite_memory_contract.py index 71b580a2f2..6d064c3486 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_sqlite_memory_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_sqlite_memory_contract.py @@ -113,15 +113,15 @@ def test_memory_has_get_prompt_request_pieces_or_equivalent(self): memory.dispose_engine() def test_memory_has_get_conversation(self): - """_callback_chat_target.py calls memory.get_conversation(conversation_id=...).""" + """_callback_chat_target.py calls memory.get_conversation_messages(conversation_id=...).""" memory = SQLiteMemory(db_path=":memory:") - assert hasattr(memory, "get_conversation") - assert callable(memory.get_conversation) + assert hasattr(memory, "get_conversation_messages") + assert callable(memory.get_conversation_messages) memory.dispose_engine() def test_get_conversation_returns_list(self, sqlite_instance): """get_conversation should return a list (empty for unknown conversation_id).""" - result = sqlite_instance.get_conversation(conversation_id="nonexistent-id") + result = sqlite_instance.get_conversation_messages(conversation_id="nonexistent-id") assert isinstance(result, list) def test_get_message_pieces_with_labels_returns_list(self, sqlite_instance): diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index be251ade7f..d467bdac55 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -35,10 +35,10 @@ def mock_memory(): """Create a mock memory instance.""" memory = MagicMock() memory.get_attack_results.return_value = [] - memory.get_conversation.return_value = [] + memory.get_conversation_messages.return_value = [] memory.get_message_pieces.return_value = [] memory.get_conversation_stats.return_value = {} - memory.get_conversation_metadata.return_value = None + memory._get_conversation.return_value = None return memory @@ -550,7 +550,7 @@ async def test_get_attack_returns_attack_details(self, attack_service, mock_memo name="My Attack", ) mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] result = await attack_service.get_attack_async(attack_result_id="test-id") @@ -582,7 +582,7 @@ async def test_get_conversation_messages_returns_messages(self, attack_service, """Test that get_conversation_messages returns messages for existing attack.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] result = await attack_service.get_conversation_messages_async( attack_result_id="test-id", conversation_id="test-id" @@ -868,7 +868,7 @@ async def test_update_attack_updates_outcome_success(self, attack_service, mock_ """Test that update_attack maps 'success' to AttackOutcome.SUCCESS.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] await attack_service.update_attack_async( attack_result_id="test-id", request=UpdateAttackRequest(outcome="success") @@ -883,7 +883,7 @@ async def test_update_attack_updates_outcome_failure(self, attack_service, mock_ """Test that update_attack maps 'failure' to AttackOutcome.FAILURE.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] await attack_service.update_attack_async( attack_result_id="test-id", request=UpdateAttackRequest(outcome="failure") @@ -896,7 +896,7 @@ async def test_update_attack_updates_outcome_undetermined(self, attack_service, """Test that update_attack maps 'undetermined' to AttackOutcome.UNDETERMINED.""" ar = make_attack_result(conversation_id="test-id", outcome=AttackOutcome.SUCCESS) mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] await attack_service.update_attack_async( attack_result_id="test-id", request=UpdateAttackRequest(outcome="undetermined") @@ -909,7 +909,7 @@ async def test_update_attack_updates_outcome_error(self, attack_service, mock_me """Test that update_attack maps 'error' to AttackOutcome.ERROR.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] await attack_service.update_attack_async( attack_result_id="test-id", request=UpdateAttackRequest(outcome="error") @@ -923,7 +923,7 @@ async def test_update_attack_refreshes_updated_at(self, attack_service, mock_mem old_time = datetime(2024, 1, 1, tzinfo=timezone.utc) ar = make_attack_result(conversation_id="test-id", updated_at=old_time) mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] await attack_service.update_attack_async( attack_result_id="test-id", request=UpdateAttackRequest(outcome="success") @@ -962,7 +962,7 @@ async def test_add_message_without_send_stamps_labels_on_pieces(self, attack_ser existing_piece = make_mock_piece(conversation_id="test-id") existing_piece.labels = {"env": "prod"} mock_memory.get_message_pieces.return_value = [existing_piece] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -985,7 +985,7 @@ async def test_add_message_with_send_passes_labels_to_normalizer(self, attack_se existing_piece = make_mock_piece(conversation_id="test-id") existing_piece.labels = {"env": "staging"} mock_memory.get_message_pieces.return_value = [existing_piece] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] with ( patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, @@ -1030,7 +1030,7 @@ async def test_add_message_send_false_without_registry_name_succeeds(self, attac ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="system", @@ -1047,7 +1047,7 @@ async def test_add_message_with_send_sends_via_normalizer(self, attack_service, ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] with ( patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, @@ -1099,7 +1099,7 @@ async def test_add_message_with_converter_ids_gets_converters(self, attack_servi ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] with ( patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, @@ -1144,7 +1144,7 @@ async def test_add_message_raises_when_attack_not_found_after_update(self, attac ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="system", @@ -1162,7 +1162,7 @@ async def test_add_message_raises_when_messages_not_found_after_update(self, att ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="system", @@ -1184,7 +1184,7 @@ async def test_add_message_persists_updated_at_timestamp(self, attack_service, m ar.metadata = {"created_at": "2026-01-01T00:00:00+00:00"} mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -1207,7 +1207,7 @@ async def test_converter_ids_propagate_even_when_preconverted(self, attack_servi ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_converter = MagicMock() mock_converter.get_identifier.return_value = ComponentIdentifier( @@ -1257,7 +1257,7 @@ async def test_add_message_no_existing_pieces_uses_request_labels(self, attack_s ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] # No existing pieces - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -1278,7 +1278,7 @@ async def test_add_message_no_existing_pieces_uses_request_labels_as_is(self, at ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -1298,7 +1298,7 @@ async def test_add_message_no_existing_pieces_no_request_labels(self, attack_ser ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] # No existing pieces - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -1453,7 +1453,7 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic mock_msg = MagicMock() mock_msg.message_pieces = [mock_piece] - mock_memory.get_conversation.return_value = [mock_msg] + mock_memory.get_conversation_messages.return_value = [mock_msg] result = await attack_service.get_conversation_messages_async( attack_result_id="test-id", conversation_id="test-id" @@ -1953,7 +1953,7 @@ async def test_stores_message_in_target_conversation(self, attack_service, mock_ } mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", @@ -2181,7 +2181,7 @@ async def test_create_related_conversation_uses_duplicate_branch(self, attack_se ar = make_attack_result(conversation_id="attack-1") mock_memory.get_attack_results.return_value = [ar] expected_target = ComponentIdentifier(class_name="TextTarget", class_module="pyrit.prompt_target") - mock_memory.get_conversation_metadata.return_value = Conversation( + mock_memory._get_conversation.return_value = Conversation( conversation_id="attack-1", target_identifier=expected_target ) @@ -2372,7 +2372,7 @@ def test_duplicate_conversation_up_to_adds_pieces_when_present(self, attack_serv make_mock_piece(conversation_id="attack-1", sequence=1), make_mock_piece(conversation_id="attack-1", sequence=2), ] - mock_memory.get_conversation.return_value = source_messages + mock_memory.get_conversation_messages.return_value = source_messages duplicated_piece = make_mock_piece(conversation_id="branch-1", sequence=0) mock_memory.duplicate_messages.return_value = ("branch-1", [duplicated_piece]) @@ -2385,7 +2385,7 @@ def test_duplicate_conversation_up_to_adds_pieces_when_present(self, attack_serv def test_duplicate_conversation_up_to_skips_persist_when_no_duplicated_pieces(self, attack_service, mock_memory): """Should not write to memory when duplicate_messages returns no pieces.""" - mock_memory.get_conversation.return_value = [make_mock_piece(conversation_id="attack-1", sequence=0)] + mock_memory.get_conversation_messages.return_value = [make_mock_piece(conversation_id="attack-1", sequence=0)] mock_memory.duplicate_messages.return_value = ("branch-empty", []) new_id = attack_service._duplicate_conversation_up_to(source_conversation_id="attack-1", cutoff_index=10) @@ -2396,7 +2396,7 @@ def test_duplicate_conversation_up_to_skips_persist_when_no_duplicated_pieces(se def test_duplicate_conversation_remaps_assistant_to_simulated(self, attack_service, mock_memory): """Should remap assistant pieces to simulated_assistant when flag is set.""" source = make_mock_piece(conversation_id="attack-1", role="assistant", sequence=0) - mock_memory.get_conversation.return_value = [source] + mock_memory.get_conversation_messages.return_value = [source] dup_piece = make_mock_piece(conversation_id="branch-1", role="assistant", sequence=0) mock_memory.duplicate_messages.return_value = ("branch-1", [dup_piece]) @@ -2450,7 +2450,7 @@ async def test_allows_matching_target(self, attack_service, mock_memory) -> None ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] with ( patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, @@ -2502,7 +2502,7 @@ async def test_allows_matching_operator(self, attack_service, mock_memory) -> No existing_piece = make_mock_piece(conversation_id="test-id") existing_piece.labels = {"operator": "alice"} mock_memory.get_message_pieces.return_value = [existing_piece] - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] request = AddMessageRequest( role="user", diff --git a/tests/unit/executor/attack/component/test_simulated_conversation.py b/tests/unit/executor/attack/component/test_simulated_conversation.py index 7e2f0eddee..66f91423d0 100644 --- a/tests/unit/executor/attack/component/test_simulated_conversation.py +++ b/tests/unit/executor/attack/component/test_simulated_conversation.py @@ -175,7 +175,7 @@ async def test_uses_adversarial_chat_as_simulated_target( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory await generate_simulated_conversation_async( @@ -219,7 +219,7 @@ async def test_creates_attack_with_score_last_turn_only_true( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory await generate_simulated_conversation_async( @@ -262,7 +262,7 @@ async def test_creates_attack_with_correct_max_turns( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory await generate_simulated_conversation_async( @@ -309,7 +309,7 @@ async def test_returns_simulated_conversation_result( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory result = await generate_simulated_conversation_async( @@ -320,8 +320,8 @@ async def test_returns_simulated_conversation_result( num_turns=3, ) - # Verify get_conversation was called with the correct conversation_id - mock_memory.get_conversation.assert_called_once_with(conversation_id=conversation_id) + # Verify get_conversation_messages was called with the correct conversation_id + mock_memory.get_conversation_messages.assert_called_once_with(conversation_id=conversation_id) # Verify the result is a list of SeedPrompts assert isinstance(result, list) @@ -357,7 +357,7 @@ async def test_passes_system_prompt_via_prepended_conversation( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory # Pass a simulated_target_system_prompt_path to test prepending behavior @@ -408,7 +408,7 @@ async def test_passes_memory_labels_to_execute( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory await generate_simulated_conversation_async( @@ -454,7 +454,7 @@ async def test_passes_converter_config_to_attack( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory await generate_simulated_conversation_async( @@ -498,7 +498,7 @@ async def test_prepends_system_message_to_conversation( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory # Pass a simulated_target_system_prompt_path to test prepending behavior @@ -546,7 +546,7 @@ async def test_uses_default_num_turns_of_3( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory # Call without specifying num_turns @@ -604,7 +604,7 @@ async def test_next_message_system_prompt_path_generates_final_user_message( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory # Configure adversarial_chat to return next message response @@ -673,7 +673,7 @@ async def test_next_message_system_prompt_path_sets_system_prompt( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory mock_adversarial_chat.send_prompt_async = AsyncMock(return_value=[next_message_response]) @@ -720,7 +720,7 @@ async def test_starting_sequence_sets_first_sequence_number( with patch("pyrit.executor.attack.multi_turn.simulated_conversation.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = iter(sample_conversation) + mock_memory.get_conversation_messages.return_value = iter(sample_conversation) mock_memory_class.get_memory_instance.return_value = mock_memory result = await generate_simulated_conversation_async( diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index 111c9e00bf..140507baa4 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -139,7 +139,7 @@ def test_system_prompt_duplicated_into_new_conversation(self): assert new_id != old_id memory = CentralMemory.get_memory_instance() - new_messages = memory.get_conversation(conversation_id=new_id) + new_messages = memory.get_conversation_messages(conversation_id=new_id) # Only the system message should be in the new conversation (not the user message) assert len(new_messages) == 1 @@ -162,7 +162,7 @@ def test_system_prompt_preserved_across_multiple_rotations(self): context.executed_turns = turn strategy._rotate_conversation_for_single_turn_target(context=context) - messages = memory.get_conversation(conversation_id=context.session.conversation_id) + messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) system_msgs = [m for m in messages if m.api_role == "system"] assert len(system_msgs) == 1, f"Turn {turn}: expected 1 system message, got {len(system_msgs)}" assert system_msgs[0].get_value() == "You are an expert." @@ -196,7 +196,7 @@ def test_no_system_prompt_yields_fresh_conversation_id(self): strategy._rotate_conversation_for_single_turn_target(context=context) assert context.session.conversation_id != old_id - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) assert len(new_messages) == 0 def test_user_messages_not_carried_over(self): @@ -229,7 +229,7 @@ def test_user_messages_not_carried_over(self): context.executed_turns = 1 strategy._rotate_conversation_for_single_turn_target(context=context) - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) roles = [m.api_role for m in new_messages] assert roles == ["system"], f"Expected only system, got {roles}" @@ -264,7 +264,7 @@ def test_multiple_system_messages_all_carried_over(self): context.executed_turns = 1 strategy._rotate_conversation_for_single_turn_target(context=context) - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) system_values = sorted(m.get_value() for m in new_messages if m.api_role == "system") assert system_values == ["Safety instructions", "System prompt 1"] assert all(m.api_role == "system" for m in new_messages) @@ -281,7 +281,7 @@ def test_empty_conversation_yields_fresh_id(self): assert context.session.conversation_id != old_id memory = CentralMemory.get_memory_instance() - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) assert len(new_messages) == 0 def test_only_system_messages_all_carried_over(self): @@ -301,7 +301,7 @@ def test_only_system_messages_all_carried_over(self): context.executed_turns = 1 strategy._rotate_conversation_for_single_turn_target(context=context) - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) assert len(new_messages) == 1 assert new_messages[0].api_role == "system" assert new_messages[0].get_value() == "Only a system message" @@ -337,7 +337,7 @@ def test_multipiece_system_message_fully_duplicated(self): context.executed_turns = 1 strategy._rotate_conversation_for_single_turn_target(context=context) - new_messages = memory.get_conversation(conversation_id=context.session.conversation_id) + new_messages = memory.get_conversation_messages(conversation_id=context.session.conversation_id) assert len(new_messages) == 1 assert new_messages[0].api_role == "system" # Both pieces should be present in the duplicated message @@ -362,7 +362,7 @@ def test_old_conversation_untouched_after_rotation(self): strategy._rotate_conversation_for_single_turn_target(context=context) # Old conversation should still have both messages intact - old_messages = memory.get_conversation(conversation_id=old_id) + old_messages = memory.get_conversation_messages(conversation_id=old_id) old_roles = [m.api_role for m in old_messages] assert old_roles == ["system", "user"] @@ -437,7 +437,7 @@ def test_single_turn_target_duplicates_only_system_messages(self): assert duplicate.objective_target_conversation_id != node.objective_target_conversation_id # The duplicate's conversation should contain only the system message - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert len(dup_messages) == 1 assert dup_messages[0].api_role == "system" assert dup_messages[0].get_value() == "TAP system prompt" @@ -472,7 +472,7 @@ def test_multi_turn_target_duplicates_full_conversation(self): assert duplicate.objective_target_conversation_id != node.objective_target_conversation_id - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) roles = [m.api_role for m in dup_messages] assert roles == ["system", "user", "assistant"] @@ -493,7 +493,7 @@ def test_single_turn_no_system_messages_yields_fresh_id(self): duplicate = node.duplicate() assert duplicate.objective_target_conversation_id != node.objective_target_conversation_id - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert len(dup_messages) == 0 def test_adversarial_chat_always_fully_duplicated(self): @@ -527,7 +527,7 @@ def test_adversarial_chat_always_fully_duplicated(self): duplicate = node.duplicate() - dup_adv_messages = memory.get_conversation(conversation_id=duplicate.adversarial_chat_conversation_id) + dup_adv_messages = memory.get_conversation_messages(conversation_id=duplicate.adversarial_chat_conversation_id) roles = [m.api_role for m in dup_adv_messages] assert roles == ["system", "user"] @@ -558,7 +558,7 @@ def test_single_turn_multiple_system_messages_all_duplicated(self): duplicate = node.duplicate() - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert all(m.api_role == "system" for m in dup_messages) dup_values = sorted(m.get_value() for m in dup_messages) assert dup_values == ["System prompt A", "System prompt B"] @@ -572,7 +572,7 @@ def test_single_turn_empty_conversation_yields_fresh_id(self): duplicate = node.duplicate() assert duplicate.objective_target_conversation_id != node.objective_target_conversation_id - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert len(dup_messages) == 0 def test_duplicate_node_has_correct_parent_id(self): @@ -615,7 +615,7 @@ def test_system_message_content_preserved_exactly(self): duplicate = node.duplicate() - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert len(dup_messages) == 1 assert dup_messages[0].get_value() == long_prompt @@ -641,7 +641,7 @@ def test_original_conversation_untouched_after_duplicate(self): node.duplicate() # Original conversation should still have both messages - orig_messages = memory.get_conversation(conversation_id=node.objective_target_conversation_id) + orig_messages = memory.get_conversation_messages(conversation_id=node.objective_target_conversation_id) orig_roles = [m.api_role for m in orig_messages] assert orig_roles == ["system", "user"] @@ -673,7 +673,7 @@ def test_single_turn_multipiece_system_message_duplicated(self): duplicate = node.duplicate() - dup_messages = memory.get_conversation(conversation_id=duplicate.objective_target_conversation_id) + dup_messages = memory.get_conversation_messages(conversation_id=duplicate.objective_target_conversation_id) assert len(dup_messages) == 1 assert dup_messages[0].api_role == "system" assert len(dup_messages[0].message_pieces) == 2 @@ -815,7 +815,7 @@ def test_branching_single_turn_target_preserves_system_across_depths(self): # Depth 2: branch (duplicate) — single-turn means only system msg is copied branch1 = node.duplicate() - branch1_msgs = memory.get_conversation(conversation_id=branch1.objective_target_conversation_id) + branch1_msgs = memory.get_conversation_messages(conversation_id=branch1.objective_target_conversation_id) assert len(branch1_msgs) == 1 assert branch1_msgs[0].api_role == "system" assert branch1_msgs[0].get_value() == "You are a red team assistant." @@ -836,13 +836,13 @@ def test_branching_single_turn_target_preserves_system_across_depths(self): memory.add_message_pieces_to_memory(message_pieces=[user2, asst2]) # Verify branch1 now has system + user + assistant - branch1_full = memory.get_conversation(conversation_id=branch1.objective_target_conversation_id) + branch1_full = memory.get_conversation_messages(conversation_id=branch1.objective_target_conversation_id) assert [m.api_role for m in branch1_full] == ["system", "user", "assistant"] # Depth 3: branch again from branch1 branch2 = branch1.duplicate() - branch2_msgs = memory.get_conversation(conversation_id=branch2.objective_target_conversation_id) + branch2_msgs = memory.get_conversation_messages(conversation_id=branch2.objective_target_conversation_id) assert len(branch2_msgs) == 1 assert branch2_msgs[0].api_role == "system" assert branch2_msgs[0].get_value() == "You are a red team assistant." @@ -875,7 +875,7 @@ def test_branching_multi_turn_target_preserves_full_history(self): branch = node.duplicate() - branch_msgs = memory.get_conversation(conversation_id=branch.objective_target_conversation_id) + branch_msgs = memory.get_conversation_messages(conversation_id=branch.objective_target_conversation_id) assert [m.api_role for m in branch_msgs] == ["system", "user", "assistant"] # Add another turn on the branch @@ -889,5 +889,5 @@ def test_branching_multi_turn_target_preserves_full_history(self): # Branch again — should have all 4 messages branch2 = branch.duplicate() - branch2_msgs = memory.get_conversation(conversation_id=branch2.objective_target_conversation_id) + branch2_msgs = memory.get_conversation_messages(conversation_id=branch2.objective_target_conversation_id) assert [m.api_role for m in branch2_msgs] == ["system", "user", "assistant", "user"] diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 49b3fe0c3b..ef18b74714 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -1321,7 +1321,7 @@ async def test_execute_async_with_message_uses_it_for_root_node(self, attack_bui mock_result.auxiliary_scores_summary = {} with patch.object(attack, "_perform_async", return_value=mock_result) as mock_perform: - with patch.object(attack._memory, "get_conversation", return_value=[]): + with patch.object(attack._memory, "get_conversation_messages", return_value=[]): with patch.object(attack._memory, "get_message_pieces", return_value=[]): with patch.object(attack._memory, "add_attack_results_to_memory", return_value=None): result = await attack.execute_async( @@ -1367,7 +1367,7 @@ async def test_execute_async_success_flow(self, attack_builder, helpers): mock_result.auxiliary_scores_summary = {} with patch.object(attack, "_perform_async", return_value=mock_result): - with patch.object(attack._memory, "get_conversation", return_value=[]): + with patch.object(attack._memory, "get_conversation_messages", return_value=[]): with patch.object(attack._memory, "get_message_pieces", return_value=[]): with patch.object(attack._memory, "add_attack_results_to_memory", return_value=None): result = await attack.execute_async(objective="Test objective", memory_labels={"test": "label"}) diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index 963e84d3e4..f0d59ac6e3 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -620,7 +620,7 @@ async def test_prompt_sending_attack_adds_prepended_to_memory( conversation_id = call_args.kwargs.get("conversation_id") memory = CentralMemory.get_memory_instance() - conversation = list(memory.get_conversation(conversation_id=conversation_id)) + conversation = list(memory.get_conversation_messages(conversation_id=conversation_id)) # Should have exactly the prepended messages in memory (mock normalizer doesn't add responses) assert len(conversation) == 2, f"Expected exactly 2 prepended messages, got {len(conversation)}" @@ -653,7 +653,7 @@ async def test_red_teaming_attack_adds_prepended_to_memory( ) memory = CentralMemory.get_memory_instance() - conversation = list(memory.get_conversation(conversation_id=result.conversation_id)) + conversation = list(memory.get_conversation_messages(conversation_id=result.conversation_id)) # Should have exactly the prepended messages in memory (mock normalizer doesn't add responses) assert len(conversation) == 2, f"Expected exactly 2 prepended messages, got {len(conversation)}" @@ -688,7 +688,7 @@ async def test_crescendo_attack_adds_prepended_to_memory( ) memory = CentralMemory.get_memory_instance() - conversation = list(memory.get_conversation(conversation_id=result.conversation_id)) + conversation = list(memory.get_conversation_messages(conversation_id=result.conversation_id)) # Should have exactly the prepended messages in memory (mock normalizer doesn't add responses) assert len(conversation) == 2, f"Expected exactly 2 prepended messages, got {len(conversation)}" @@ -753,7 +753,7 @@ async def test_tap_attack_adds_prepended_to_memory( ) memory = CentralMemory.get_memory_instance() - conversation = list(memory.get_conversation(conversation_id=result.conversation_id)) + conversation = list(memory.get_conversation_messages(conversation_id=result.conversation_id)) # Should have exactly the prepended messages in memory (mock normalizer doesn't add responses) assert len(conversation) == 2, f"Expected exactly 2 prepended messages, got {len(conversation)}" @@ -895,7 +895,7 @@ def _get_adversarial_chat_text_values(*, adversarial_chat_conversation_id: str) List of text values from all text pieces in the adversarial conversation. """ memory = CentralMemory.get_memory_instance() - conversation = list(memory.get_conversation(conversation_id=adversarial_chat_conversation_id)) + conversation = list(memory.get_conversation_messages(conversation_id=adversarial_chat_conversation_id)) text_values = [] for msg in conversation: diff --git a/tests/unit/executor/attack/test_error_skip_scoring.py b/tests/unit/executor/attack/test_error_skip_scoring.py index 82b1ce04ce..211568c60c 100644 --- a/tests/unit/executor/attack/test_error_skip_scoring.py +++ b/tests/unit/executor/attack/test_error_skip_scoring.py @@ -64,7 +64,7 @@ def mock_scorer(): def mock_memory(): """Create a mock memory instance""" memory = MagicMock() - memory.get_conversation.return_value = [] + memory.get_conversation_messages.return_value = [] memory.add_message_to_memory = MagicMock() return memory diff --git a/tests/unit/executor/benchmark/test_fairness_bias.py b/tests/unit/executor/benchmark/test_fairness_bias.py index 8e4b5a2275..3ab9dca608 100644 --- a/tests/unit/executor/benchmark/test_fairness_bias.py +++ b/tests/unit/executor/benchmark/test_fairness_bias.py @@ -256,7 +256,7 @@ async def test_perform_async_calls_prompt_sending_attack( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) @@ -387,7 +387,7 @@ async def test_execute_async_with_required_parameters( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) @@ -416,7 +416,7 @@ async def test_execute_async_with_optional_parameters( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) @@ -454,7 +454,7 @@ async def test_execute_async_multiple_experiments( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) @@ -508,7 +508,7 @@ async def test_full_benchmark_workflow( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) @@ -550,7 +550,7 @@ async def test_benchmark_with_memory_labels( with patch("pyrit.executor.benchmark.fairness_bias.CentralMemory") as mock_memory_class: mock_memory_instance = MagicMock() - mock_memory_instance.get_conversation.return_value = mock_conversation_pieces + mock_memory_instance.get_conversation_messages.return_value = mock_conversation_pieces mock_memory_class.get_memory_instance.return_value = mock_memory_instance benchmark = FairnessBiasBenchmark(objective_target=mock_prompt_target) diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 8554a228d1..04c4ead899 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -604,7 +604,7 @@ def test_add_conversation_to_memory_records_target_for_plain_message_writes(sqli message_pieces=[MessagePiece(role="user", original_value="hi", conversation_id=conversation_id)] ) - metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + metadata = sqlite_instance._get_conversation(conversation_id=conversation_id) assert metadata is not None assert metadata.target_identifier.hash == target_id.hash @@ -630,7 +630,7 @@ def test_message_writes_without_registration_create_no_conversation_row(sqlite_i message_pieces=[MessagePiece(role="user", original_value="hi", conversation_id=conversation_id)] ) - assert sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) is None + assert sqlite_instance._get_conversation(conversation_id=conversation_id) is None # The messages themselves still persist. assert len(sqlite_instance.get_message_pieces(conversation_id=conversation_id)) == 1 @@ -650,7 +650,7 @@ def test_add_conversation_to_memory_same_target_reregister_is_noop(sqlite_instan conversation=Conversation(conversation_id=conversation_id, target_identifier=target) ) - metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + metadata = sqlite_instance._get_conversation(conversation_id=conversation_id) assert metadata is not None assert metadata.target_identifier.hash == target.hash @@ -675,7 +675,7 @@ def test_add_conversation_to_memory_different_target_reregister_raises(sqlite_in ) # The originally recorded target is left untouched. - metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + metadata = sqlite_instance._get_conversation(conversation_id=conversation_id) assert metadata is not None assert metadata.target_identifier.hash == target_a.hash @@ -1348,7 +1348,7 @@ def test_get_request_from_response_success(sqlite_instance: MemoryInterface): sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) # Get the conversation and extract the response - conversation = sqlite_instance.get_conversation(conversation_id=conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=conversation_id) response = conversation[1] # Retrieve the request that produced this response @@ -1360,6 +1360,35 @@ def test_get_request_from_response_success(sqlite_instance: MemoryInterface): assert request.conversation_id == conversation_id +def test_get_conversation_is_deprecated_and_delegates_to_messages(sqlite_instance: MemoryInterface): + """get_conversation warns and returns the same result as get_conversation_messages.""" + conversation_id = str(uuid4()) + pieces = [ + MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + conversation_id=conversation_id, + sequence=0, + ), + MessagePiece( + role="assistant", + original_value="Hi there", + converted_value="Hi there", + conversation_id=conversation_id, + sequence=1, + ), + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + with pytest.warns(DeprecationWarning, match="get_conversation_messages"): + deprecated_result = sqlite_instance.get_conversation(conversation_id=conversation_id) + + expected = sqlite_instance.get_conversation_messages(conversation_id=conversation_id) + assert [m.get_value() for m in deprecated_result] == [m.get_value() for m in expected] + assert len(deprecated_result) == 2 + + def test_get_request_from_response_multi_turn_conversation(sqlite_instance: MemoryInterface): """Test get_request_from_response in a multi-turn conversation.""" conversation_id = str(uuid4()) @@ -1397,7 +1426,7 @@ def test_get_request_from_response_multi_turn_conversation(sqlite_instance: Memo ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - conversation = sqlite_instance.get_conversation(conversation_id=conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=conversation_id) # Test getting request for the second response second_response = conversation[3] @@ -1423,7 +1452,7 @@ def test_get_request_from_response_raises_error_for_non_assistant_role(sqlite_in ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - conversation = sqlite_instance.get_conversation(conversation_id=conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=conversation_id) user_message = conversation[0] with pytest.raises(ValueError, match="The provided request is not a response \\(role must be 'assistant'\\)."): @@ -1446,7 +1475,7 @@ def test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - conversation = sqlite_instance.get_conversation(conversation_id=conversation_id) + conversation = sqlite_instance.get_conversation_messages(conversation_id=conversation_id) response_without_request = conversation[0] with pytest.raises(ValueError, match="The provided request does not have a preceding request \\(sequence < 1\\)."): diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 7289285218..34e9671461 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -212,7 +212,7 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): memory_interface.add_message_pieces_to_memory(message_pieces=[piece]) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) + retrieved_entries = memory_interface.get_conversation_messages(conversation_id=specific_conversation_id) # Verify that the retrieved entry matches the inserted entry assert len(retrieved_entries) == 1 @@ -228,7 +228,7 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): assert converter_identifiers[0].class_name == "Base64Converter" # The target identifier is conversation-scoped and stored in the Conversations table. - metadata = memory_interface.get_conversation_metadata(conversation_id=specific_conversation_id) + metadata = memory_interface._get_conversation(conversation_id=specific_conversation_id) assert metadata is not None assert metadata.target_identifier.class_name == "TextTarget" diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 1e88fe9927..3a61123c11 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -534,7 +534,7 @@ def test_get_memories_with_json_properties(sqlite_instance): sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) + retrieved_entries = sqlite_instance.get_conversation_messages(conversation_id=specific_conversation_id) # Verify that the retrieved entry matches the inserted entry assert len(retrieved_entries) == 1 @@ -550,7 +550,7 @@ def test_get_memories_with_json_properties(sqlite_instance): assert converter_identifiers[0].class_name == "Base64Converter" # The target identifier is conversation-scoped and stored in the Conversations table. - metadata = sqlite_instance.get_conversation_metadata(conversation_id=specific_conversation_id) + metadata = sqlite_instance._get_conversation(conversation_id=specific_conversation_id) assert metadata is not None assert metadata.target_identifier.class_name == "TextTarget" @@ -587,7 +587,7 @@ def test_register_conversation_none_target_does_not_clobber(sqlite_instance): ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece]) - metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + metadata = sqlite_instance._get_conversation(conversation_id=conversation_id) assert metadata is not None assert metadata.target_identifier is not None assert metadata.target_identifier.class_name == "TextTarget" diff --git a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py index 5358881219..a219a8fab5 100644 --- a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py +++ b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py @@ -107,7 +107,7 @@ async def test_complete_chat_async_bad_json_response(aml_online_chat: AzureMLCha async def test_send_prompt_async_bad_request_error_adds_to_memory(aml_online_chat: AzureMLChatTarget): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() aml_online_chat._memory = mock_memory @@ -123,7 +123,7 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(aml_online_cha with pytest.raises(HTTPStatusError) as bre: await aml_online_chat.send_prompt_async(message=message) - aml_online_chat._memory.get_conversation.assert_called_once_with(conversation_id="123") + aml_online_chat._memory.get_conversation_messages.assert_called_once_with(conversation_id="123") aml_online_chat._memory.add_message_to_memory.assert_called_once_with(request=message) assert str(bre.value) == "Bad Request" @@ -131,7 +131,7 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(aml_online_cha async def test_send_prompt_async_rate_limit_exception_adds_to_memory(aml_online_chat: AzureMLChatTarget): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() aml_online_chat._memory = mock_memory @@ -146,7 +146,7 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory(aml_online_ with pytest.raises(RateLimitException) as rle: await aml_online_chat.send_prompt_async(message=message) - aml_online_chat._memory.get_conversation.assert_called_once_with(conversation_id="123") + aml_online_chat._memory.get_conversation_messages.assert_called_once_with(conversation_id="123") aml_online_chat._memory.add_message_to_memory.assert_called_once_with(request=message) assert str(rle.value) == "Status Code: 429, Message: Rate Limit Exception" diff --git a/tests/unit/prompt_target/target/test_image_target.py b/tests/unit/prompt_target/target/test_image_target.py index 3eca3e7607..20ad7eb3b1 100644 --- a/tests/unit/prompt_target/target/test_image_target.py +++ b/tests/unit/prompt_target/target/test_image_target.py @@ -298,7 +298,7 @@ async def test_send_prompt_async_empty_response_adds_memory( sample_conversations: MutableSequence[MessagePiece], ) -> None: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() request = sample_conversations[0] @@ -324,7 +324,7 @@ async def test_send_prompt_async_rate_limit_adds_memory( sample_conversations: MutableSequence[MessagePiece], ) -> None: mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() request = sample_conversations[0] @@ -492,7 +492,7 @@ async def test_validate_previous_conversations( prior_message = Message(message_pieces=[message_piece]) mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [prior_message] + mock_memory.get_conversation_messages.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() image_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 10d4a447f7..31ff0a3573 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -70,7 +70,7 @@ async def test_openai_chat_target_calls_normalize_async(): user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory mock_completion = _create_mock_chat_completion("world") @@ -99,7 +99,7 @@ async def test_openai_chat_target_sends_normalized_to_construct_request(): adapted_msg = _make_message(role="user", content="adapted") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory mock_completion = _create_mock_chat_completion("response") @@ -152,7 +152,7 @@ async def test_openai_chat_target_memory_not_mutated(): memory_conversation: MutableSequence[Message] = [system_msg] mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = memory_conversation + mock_memory.get_conversation_messages.return_value = memory_conversation target._memory = mock_memory mock_completion = _create_mock_chat_completion("response") @@ -181,7 +181,7 @@ async def test_openai_response_target_calls_normalize_async(): user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory # Mock the API to return a simple response (no tool calls) @@ -222,7 +222,7 @@ async def test_azure_ml_target_calls_normalize_async(): user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory with ( @@ -247,7 +247,7 @@ async def test_azure_ml_target_sends_normalized_to_complete_chat(): adapted_msg = _make_message(role="user", content="adapted") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory with ( @@ -289,7 +289,7 @@ async def test_azure_ml_target_memory_not_mutated(): memory_conversation: MutableSequence[Message] = [system_msg] mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = memory_conversation + mock_memory.get_conversation_messages.return_value = memory_conversation target._memory = mock_memory with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): @@ -326,7 +326,7 @@ async def test_azure_ml_system_squash_via_configuration_pipeline(): user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [system_msg] + mock_memory.get_conversation_messages.return_value = [system_msg] target._memory = mock_memory with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: @@ -359,12 +359,12 @@ async def test_get_normalized_conversation_fetches_history_and_appends_message() user_msg = _make_message(role="user", content="new question") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [history_msg] + mock_memory.get_conversation_messages.return_value = [history_msg] target._memory = mock_memory result = await target._get_normalized_conversation_async(message=user_msg) - mock_memory.get_conversation.assert_called_once_with(conversation_id="conv1") + mock_memory.get_conversation_messages.assert_called_once_with(conversation_id="conv1") assert len(result) == 2 assert result[0].get_value() == "previous answer" assert result[1].get_value() == "new question" @@ -382,7 +382,7 @@ async def test_get_normalized_conversation_empty_history(): user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory result = await target._get_normalized_conversation_async(message=user_msg) @@ -405,7 +405,7 @@ async def test_get_normalized_conversation_does_not_mutate_memory(): memory_list: MutableSequence[Message] = [history_msg] mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = memory_list + mock_memory.get_conversation_messages.return_value = memory_list target._memory = mock_memory await target._get_normalized_conversation_async(message=user_msg) @@ -441,7 +441,7 @@ async def test_get_normalized_conversation_runs_pipeline(): user_msg = _make_message(role="user", content="hi") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [system_msg] + mock_memory.get_conversation_messages.return_value = [system_msg] target._memory = mock_memory result = await target._get_normalized_conversation_async(message=user_msg) @@ -467,7 +467,7 @@ async def test_get_normalized_conversation_passthrough_when_no_adaptation_needed user_msg = _make_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [system_msg] + mock_memory.get_conversation_messages.return_value = [system_msg] target._memory = mock_memory result = await target._get_normalized_conversation_async(message=user_msg) diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 0f7031b6e5..25d3664f31 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -266,7 +266,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j ), ) mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -307,7 +307,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) - target._memory.get_conversation.return_value = [] + target._memory.get_conversation_messages.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -317,7 +317,7 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( target: OpenAIChatTarget, ): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -338,7 +338,7 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( async def test_send_prompt_async_bad_request_error_adds_to_memory(target: OpenAIChatTarget): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -474,7 +474,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) - target._memory.get_conversation.return_value = [] + target._memory.get_conversation_messages.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -628,7 +628,7 @@ def test_is_response_format_json_no_metadata(target: OpenAIChatTarget): async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -681,7 +681,7 @@ async def test_send_prompt_async_other_http_error(patch_central_database): ) message = Message(message_pieces=[message_piece]) target._memory = MagicMock() - target._memory.get_conversation.return_value = [] + target._memory.get_conversation_messages.return_value = [] # Create proper mock request and response for APIStatusError mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index d1e1e9afef..67baff3de8 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -289,7 +289,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory( openai_response_json: dict, target: OpenAIResponseTarget ): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -329,7 +329,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) - target._memory.get_conversation.return_value = [] + target._memory.get_conversation_messages.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -342,7 +342,7 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( target: OpenAIResponseTarget, ): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -356,13 +356,13 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( with pytest.raises(RateLimitException): await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") + target._memory.get_conversation_messages.assert_called_once_with(conversation_id="123") target._memory.add_message_to_memory.assert_called_once_with(request=message) async def test_send_prompt_async_bad_request_error_adds_to_memory(target: OpenAIResponseTarget): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() target._memory = mock_memory @@ -378,7 +378,7 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(target: OpenAI with pytest.raises(BadRequestError): await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") + target._memory.get_conversation_messages.assert_called_once_with(conversation_id="123") target._memory.add_message_to_memory.assert_called_once_with(request=message) @@ -460,7 +460,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) - target._memory.get_conversation.return_value = [] + target._memory.get_conversation_messages.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -984,7 +984,7 @@ async def mock_sdk_create(**kwargs): # Verify intermediate messages were NOT persisted to memory by the target # (The normalizer will handle persistence when messages are returned) - all_messages = target._memory.get_conversation(conversation_id=shared_conversation_id) + all_messages = target._memory.get_conversation_messages(conversation_id=shared_conversation_id) assert len(all_messages) == 0, ( f"Expected 0 messages in memory (target doesn't persist), got {len(all_messages)}" ) diff --git a/tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py b/tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py index 7551c02842..df92486b53 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py +++ b/tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py @@ -245,7 +245,7 @@ async def mock_create(**kwargs): # Verify memory does NOT contain intermediate messages # (targets no longer persist intermediate messages - normalizer handles that) memory = CentralMemory.get_memory_instance() - conversation = memory.get_conversation(conversation_id=conversation_id) + conversation = memory.get_conversation_messages(conversation_id=conversation_id) # Target doesn't persist messages anymore - all returned messages will be persisted by normalizer assert len(conversation) == 0 diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index e1f18b9222..651fe5c770 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -237,7 +237,7 @@ async def test_history_squash_preserves_metadata_on_normalized_message(): user_msg = _make_lineage_message(role="user", content="follow-up question") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [history_msg] + mock_memory.get_conversation_messages.return_value = [history_msg] target._memory = mock_memory normalized = await target._get_normalized_conversation_async(message=user_msg) @@ -282,7 +282,7 @@ async def test_response_preserves_metadata_after_history_squash(): user_msg = _make_lineage_message(role="user", content="follow-up question") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [history_msg] + mock_memory.get_conversation_messages.return_value = [history_msg] target._memory = mock_memory mock_completion = _make_mock_chat_completion("target response") @@ -328,7 +328,7 @@ async def test_system_squash_preserves_metadata(): user_msg = _make_lineage_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [system_msg] + mock_memory.get_conversation_messages.return_value = [system_msg] target._memory = mock_memory normalized = await target._get_normalized_conversation_async(message=user_msg) @@ -379,7 +379,7 @@ async def test_history_squash_propagates_lineage_to_all_pieces(): ) mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [history_msg] + mock_memory.get_conversation_messages.return_value = [history_msg] target._memory = mock_memory normalized = await target._get_normalized_conversation_async(message=user_msg) @@ -414,7 +414,7 @@ async def test_conversation_id_stamped_on_all_but_full_lineage_only_on_last(): user_msg = _make_lineage_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [history_msg] + mock_memory.get_conversation_messages.return_value = [history_msg] target._memory = mock_memory # Simulate a normalizer that inserts a new message with a random conversation_id. @@ -484,7 +484,7 @@ async def test_json_schema_stripped_for_non_schema_target_survives_lineage(): user_msg = Message(message_pieces=[piece]) mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory normalized = await target._get_normalized_conversation_async(message=user_msg) @@ -519,7 +519,7 @@ async def test_json_schema_only_metadata_fully_stripped_survives_lineage(): user_msg = Message(message_pieces=[piece]) mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory normalized = await target._get_normalized_conversation_async(message=user_msg) @@ -542,7 +542,7 @@ async def test_no_warning_when_message_count_unchanged(): user_msg = _make_lineage_message(role="user", content="hello") mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index 53fac1c5ef..019caac2cf 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -242,6 +242,20 @@ async def test_send_prompt_async_invalid_request(target): assert "image_path" in str(excinfo.value) +async def test_send_prompt_to_target_raises_without_conversation_id(target): + message_piece = MessagePiece( + original_value="hello", + original_value_data_type="text", + converted_value="hello", + converted_value_data_type="text", + role="user", + conversation_id=None, + ) + message = Message(message_pieces=[message_piece]) + with pytest.raises(ValueError, match="requires a conversation_id"): + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + async def test_receive_events_empty_output(target: RealtimeTarget): """Test handling of response.done event with empty output array.""" mock_connection = AsyncMock() diff --git a/tests/unit/prompt_target/target/test_tts_target.py b/tests/unit/prompt_target/target/test_tts_target.py index adcdeabf1c..fed2dacfc3 100644 --- a/tests/unit/prompt_target/target/test_tts_target.py +++ b/tests/unit/prompt_target/target/test_tts_target.py @@ -90,7 +90,7 @@ async def test_tts_validate_previous_conversations( prior_message = Message(message_pieces=[message_piece]) mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [prior_message] + mock_memory.get_conversation_messages.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() tts_target._memory = mock_memory @@ -149,7 +149,7 @@ async def test_tts_send_prompt_async_exception_adds_to_memory( exception_class: type[BaseException], ): mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] mock_memory.add_message_to_memory = AsyncMock() tts_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_websocket_copilot_target.py b/tests/unit/prompt_target/target/test_websocket_copilot_target.py index 2904994bb4..936def927a 100644 --- a/tests/unit/prompt_target/target/test_websocket_copilot_target.py +++ b/tests/unit/prompt_target/target/test_websocket_copilot_target.py @@ -121,7 +121,7 @@ def patch_convert_local_image_to_data_url(): @pytest.fixture def mock_memory(): memory = MagicMock() - memory.get_conversation.return_value = [] + memory.get_conversation_messages.return_value = [] memory.add_message_to_memory = AsyncMock() return memory @@ -751,21 +751,21 @@ def test_is_start_of_session_with_empty_history(self, mock_authenticator): target = WebSocketCopilotTarget(authenticator=mock_authenticator) mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] + mock_memory.get_conversation_messages.return_value = [] target._memory = mock_memory conversation_id = "test_conv_123" result = target._is_start_of_session(conversation_id=conversation_id) assert result is True - mock_memory.get_conversation.assert_called_once_with(conversation_id=conversation_id) + mock_memory.get_conversation_messages.assert_called_once_with(conversation_id=conversation_id) def test_is_start_of_session_with_existing_history(self, mock_authenticator): target = WebSocketCopilotTarget(authenticator=mock_authenticator) mock_memory = MagicMock() mock_message = MagicMock() - mock_memory.get_conversation.return_value = [mock_message] + mock_memory.get_conversation_messages.return_value = [mock_message] target._memory = mock_memory conversation_id = "test_conv_123" @@ -827,6 +827,16 @@ async def test_send_prompt_async_successful(self, mock_authenticator, make_messa assert responses[0].message_pieces[0].converted_value == "Response from Copilot" assert responses[0].message_pieces[0].api_role == "assistant" + async def test_send_prompt_to_target_raises_without_conversation_id( + self, mock_authenticator, make_message_piece, mock_memory + ): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + target._memory = mock_memory + message = Message(message_pieces=[make_message_piece("Hello", conversation_id=None)]) + + with pytest.raises(ValueError, match="requires a conversation_id"): + await target._send_prompt_to_target_async(normalized_conversation=[message]) + async def test_send_prompt_async_with_exceptions(self, mock_authenticator, make_message_piece, mock_memory): from pyrit.exceptions import EmptyResponseException diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 9caac8497a..5be57b43db 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -157,7 +157,7 @@ def test_set_system_prompt_writes_system_message_when_capabilities_present(): conversation_id=conversation_id, ) - messages = target._memory.get_conversation(conversation_id=conversation_id) + messages = target._memory.get_conversation_messages(conversation_id=conversation_id) assert len(messages) == 1 pieces = messages[0].message_pieces assert len(pieces) == 1 From d6a84dfd7c92a0de7e3c6336b6425fdb5a4c1e31 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 10 Jun 2026 16:27:25 -0700 Subject: [PATCH 10/12] Fix ty errors from conversation_id str | None change - conversations.py: widen grouping dict key to str | None to match MessagePiece.conversation_id - fuzzer.py: fail loud (ValueError) when a scored response piece lacks a conversation_id before recording it as a jailbreak conversation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/promptgen/fuzzer/fuzzer.py | 8 +++++++- pyrit/models/messages/conversations.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 492ae2ee62..43674810fa 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -1077,6 +1077,9 @@ def _process_scoring_results( Returns: int: The number of jailbreaks found. + + Raises: + ValueError: If a scored response piece has no conversation_id. """ jailbreak_count = 0 response_pieces = [response.message_pieces[0] for response in responses] @@ -1084,7 +1087,10 @@ def _process_scoring_results( for index, score in enumerate(scores): if self._is_jailbreak(score): jailbreak_count += 1 - context.jailbreak_conversation_ids.append(response_pieces[index].conversation_id) + conversation_id = response_pieces[index].conversation_id + if conversation_id is None: + raise ValueError("Response piece has no conversation_id; cannot record jailbreak conversation.") + context.jailbreak_conversation_ids.append(conversation_id) # Update tracking context.total_jailbreak_count += jailbreak_count diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index e4cb34a121..9e829a28c8 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -158,7 +158,7 @@ def group_message_pieces_into_conversations( return [] # Group pieces by conversation ID - conversations: dict[str, list[MessagePiece]] = {} + conversations: dict[str | None, list[MessagePiece]] = {} for piece in message_pieces: conv_id = piece.conversation_id if conv_id not in conversations: From 6b61ab01a51fb77cfe67a18c100404d7cff41251 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 11 Jun 2026 12:30:03 -0700 Subject: [PATCH 11/12] Fix false-positive labels deprecation warning on empty dict Treat a falsy/empty deprecated kwarg value (e.g. labels={}, the field default) as not supplied so internal sites forwarding labels=.labels on the happy path do not trip a spurious DeprecationWarning. Real values still warn. Adds regression tests for the empty-dict and construct_response_from_request cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/messages/message_piece.py | 8 +++++++- tests/unit/models/test_message_piece.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 64dc9c0ed0..2e7de96b16 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -93,13 +93,19 @@ def _warn_on_deprecated_kwargs(cls, data: Any) -> Any: """ Emit DeprecationWarning for each deprecated kwarg explicitly passed. + Only a truthy value counts as "passed". An empty/falsy value (e.g. + ``labels={}``, the field default) is treated as not supplied, so callers + that forward ``labels=.labels`` on the happy path do not trip a + spurious warning. This matches the post-construction assignment pattern + used elsewhere (``piece.labels = labels`` guarded by ``if labels:``). + Returns: The (unchanged) input ``data`` so validation can continue. """ if not isinstance(data, dict): return data for kwarg, removed_in in _DEPRECATED_KWARGS: - if data.get(kwarg) is not None: + if data.get(kwarg): print_deprecation_message( old_item=f"MessagePiece(..., {kwarg}=...)", new_item="MessagePiece(...)", diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index e01b83e812..84d9ff40c3 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -1108,6 +1108,29 @@ def test_labels_omitted_no_warning(self): msgs = self._emit_deprecation_msgs() assert not any("labels" in str(m.message) for m in msgs) + def test_labels_empty_dict_no_warning(self): + """An explicit empty ``labels={}`` (the field default) must not warn. + + Internal call sites forward ``labels=.labels`` which is ``{}`` on + the happy path; this regression-guards that such forwarding stays silent. + """ + msgs = self._emit_deprecation_msgs(labels={}) + assert not any("labels" in str(m.message) for m in msgs) + + def test_construct_response_from_request_default_labels_no_warning(self): + """``construct_response_from_request`` on a request with default labels is silent. + + Reproduces the reported false positive: every response construction warned + because the request's default ``labels={}`` was forwarded through the + ``MessagePiece`` constructor. + """ + request = MessagePiece(role="user", original_value="hello", conversation_id="conv-1") + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + construct_response_from_request(request=request, response_text_pieces=["hi"]) + deprecation_msgs = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert not any("labels" in str(m.message) for m in deprecation_msgs) + def test_memory_load_roundtrip_does_not_emit_deprecation_warnings(self) -> None: """Reconstructing a MessagePiece from PromptMemoryEntry must not emit deprecations. From 2bd042022d768ed0fe73ccacfcdac37403b3b3e9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 11 Jun 2026 13:00:41 -0700 Subject: [PATCH 12/12] Fix false-positive labels deprecation warning in backend mappers Apply the same truthy-guard fix to request_to_pyrit_message and request_piece_to_pyrit_message_piece: an empty/falsy labels (e.g. {} forwarded on the happy path by _resolve_labels) is treated as not supplied, so no spurious DeprecationWarning fires. Real labels still warn. Also align the merged TestPyritMessagesToDtoRealObjects tests with the caller-owned conversation_id validation (set conversation_id on pieces) and the get_conversation -> get_conversation_messages rename. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/mappers/attack_mappers.py | 10 ++++- tests/unit/backend/test_mappers.py | 54 ++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 49f025336e..d408acfa54 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -536,7 +536,10 @@ def request_piece_to_pyrit_message_piece( Returns: PyritMessagePiece domain object. """ - if labels is not None: + # Only a truthy value counts as "passed"; an empty/falsy ``labels`` (e.g. {} + # forwarded on the happy path) is treated as not supplied to avoid a spurious + # warning. Matches MessagePiece's deprecated-kwarg guard. + if labels: print_deprecation_message( old_item="request_piece_to_pyrit_message_piece(..., labels=...)", new_item="request_piece_to_pyrit_message_piece(...)", @@ -582,7 +585,10 @@ def request_to_pyrit_message( Returns: PyritMessage ready to send to the target. """ - if labels is not None: + # Only a truthy value counts as "passed"; an empty/falsy ``labels`` (e.g. {} + # forwarded on the happy path) is treated as not supplied to avoid a spurious + # warning. Matches MessagePiece's deprecated-kwarg guard. + if labels: print_deprecation_message( old_item="request_to_pyrit_message(..., labels=...)", new_item="request_to_pyrit_message(...)", diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 17a33c073c..4a4e7aa1e8 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -679,7 +679,7 @@ async def test_scores_are_fetched_from_memory_and_attached(self, sqlite_instance from pyrit.models import MessagePiece as RealPyritMessagePiece from pyrit.models import Score as RealPyritScore - piece = RealPyritMessagePiece(role="user", original_value="hi") + piece = RealPyritMessagePiece(role="user", original_value="hi", conversation_id="real-conv-scores") sqlite_instance.add_message_to_memory(request=RealPyritMessage(message_pieces=[piece])) score = RealPyritScore( @@ -691,7 +691,7 @@ async def test_scores_are_fetched_from_memory_and_attached(self, sqlite_instance ) sqlite_instance.add_scores_to_memory(scores=[score]) - reloaded = sqlite_instance.get_conversation(conversation_id=piece.conversation_id) + reloaded = sqlite_instance.get_conversation_messages(conversation_id=piece.conversation_id) result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) assert len(result) == 1 @@ -709,10 +709,10 @@ async def test_empty_scores_when_none_recorded(self, sqlite_instance) -> None: from pyrit.models import Message as RealPyritMessage from pyrit.models import MessagePiece as RealPyritMessagePiece - piece = RealPyritMessagePiece(role="user", original_value="hi") + piece = RealPyritMessagePiece(role="user", original_value="hi", conversation_id="real-conv-empty") sqlite_instance.add_message_to_memory(request=RealPyritMessage(message_pieces=[piece])) - reloaded = sqlite_instance.get_conversation(conversation_id=piece.conversation_id) + reloaded = sqlite_instance.get_conversation_messages(conversation_id=piece.conversation_id) result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) assert result[0].pieces[0].scores == [] @@ -746,7 +746,7 @@ async def test_scores_are_grouped_per_piece_across_multiple_pieces(self, sqlite_ ] ) - reloaded = sqlite_instance.get_conversation(conversation_id=conv_id) + reloaded = sqlite_instance.get_conversation_messages(conversation_id=conv_id) result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) by_role = {msg.role: msg for msg in result} @@ -928,6 +928,29 @@ def test_labels_emit_deprecation_warning(self) -> None: assert mock_deprecation.call_count == 2 + def test_empty_labels_no_deprecation_warning(self) -> None: + """An explicit empty ``labels={}`` (forwarded on the happy path) must not warn.""" + request = MagicMock() + request.role = "user" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.prompt_metadata = None + piece.mime_type = None + piece.original_prompt_id = None + request.pieces = [piece] + + with patch("pyrit.backend.mappers.attack_mappers.print_deprecation_message") as mock_deprecation: + request_to_pyrit_message( + request=request, + conversation_id="conv-1", + sequence=0, + labels={}, + ) + + mock_deprecation.assert_not_called() + class TestRequestPieceToPyritMessagePiece: """Tests for request_piece_to_pyrit_message_piece function.""" @@ -1069,6 +1092,27 @@ def test_labels_emit_deprecation_warning(self) -> None: mock_deprecation.assert_called_once() + def test_empty_labels_no_deprecation_warning(self) -> None: + """An explicit empty ``labels={}`` (forwarded on the happy path) must not warn.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.mime_type = None + piece.prompt_metadata = None + piece.original_prompt_id = None + + with patch("pyrit.backend.mappers.attack_mappers.print_deprecation_message") as mock_deprecation: + request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + labels={}, + ) + + mock_deprecation.assert_not_called() + def test_labels_default_to_empty_dict(self) -> None: """Test that labels default to empty dict when not provided.""" piece = MagicMock()