Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
72f02b6
Introduce Conversation model; move target identifier off MessagePiece
rlundeen2 Jun 7, 2026
a7e5fe1
Fix tests and docs for Conversation model cutover
rlundeen2 Jun 7, 2026
e7a6918
Merge remote-tracking branch 'origin/main' into rlundeen2/conversatio…
rlundeen2 Jun 7, 2026
8e3c549
Fix ConversationEntry docstring and add None-clobber upsert test
rlundeen2 Jun 7, 2026
3ad8d75
Make conversation_id caller-owned and harden memory choke point
rlundeen2 Jun 9, 2026
cec9ccd
Add explicit conversation registration; drop target_identifier from m…
rlundeen2 Jun 9, 2026
63594c7
Rename _add_message_pieces_to_storage to _add_message_pieces_to_memor…
rlundeen2 Jun 9, 2026
1946f74
Merge remote-tracking branch 'origin/main' into rlundeen2/conversatio…
rlundeen2 Jun 9, 2026
5aac569
Add coverage for conversation migration, upsert, and attack_identifie…
rlundeen2 Jun 9, 2026
4ae9bbc
Use Conversation object param and insert-only conversation registration
rlundeen2 Jun 10, 2026
2538d0b
Merge origin/main into rlundeen2/conversation-model
rlundeen2 Jun 10, 2026
db46ea4
Deprecate MemoryInterface.get_conversation; add get_conversation_mess…
rlundeen2 Jun 10, 2026
d6a84df
Fix ty errors from conversation_id str | None change
rlundeen2 Jun 10, 2026
6b61ab0
Fix false-positive labels deprecation warning on empty dict
rlundeen2 Jun 11, 2026
146a803
Merge remote-tracking branch 'origin/main' into rlundeen2/conversatio…
rlundeen2 Jun 11, 2026
2bd0420
Fix false-positive labels deprecation warning in backend mappers
rlundeen2 Jun 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/code/executor/attack/2_red_teaming_attack.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@
"\n",
"num_turns_to_remove = 2\n",
"memory = CentralMemory.get_memory_instance()\n",
"conversation_history = memory.get_conversation(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]\n",
"conversation_history = memory.get_conversation_messages(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]\n",
"prepended_conversation = conversation_history\n",
"\"\"\"\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/code/executor/attack/2_red_teaming_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@

num_turns_to_remove = 2
memory = CentralMemory.get_memory_instance()
conversation_history = memory.get_conversation(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]
conversation_history = memory.get_conversation_messages(conversation_id=result.conversation_id)[:-num_turns_to_remove*2]
prepended_conversation = conversation_history
"""

Expand Down
2 changes: 1 addition & 1 deletion doc/code/executor/attack/barge_in_attack.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"\n",
"# Inspect memory to verify the barge-in landed in metadata.\n",
"memory = CentralMemory.get_memory_instance()\n",
"turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id)\n",
"turns = memory.get_conversation_messages(conversation_id=barge_in_result.conversation_id)\n",
"print(f\"\\nPersisted pieces ({len(turns)} messages):\")\n",
"for message in turns:\n",
" for piece in message.message_pieces:\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/code/executor/attack/barge_in_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def barge_in_source():

# Inspect memory to verify the barge-in landed in metadata.
memory = CentralMemory.get_memory_instance()
turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id)
turns = memory.get_conversation_messages(conversation_id=barge_in_result.conversation_id)
print(f"\nPersisted pieces ({len(turns)} messages):")
for message in turns:
for piece in message.message_pieces:
Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/2_basic_memory_programming.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"memory.add_message_to_memory(request=message_list[1].to_message())\n",
"memory.add_message_to_memory(request=message_list[2].to_message())\n",
"\n",
"entries = memory.get_conversation(conversation_id=conversation_id)\n",
"entries = memory.get_conversation_messages(conversation_id=conversation_id)\n",
"\n",
"for entry in entries:\n",
" print(entry)"
Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/2_basic_memory_programming.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
memory.add_message_to_memory(request=message_list[1].to_message())
memory.add_message_to_memory(request=message_list[2].to_message())

entries = memory.get_conversation(conversation_id=conversation_id)
entries = memory.get_conversation_messages(conversation_id=conversation_id)

for entry in entries:
print(entry)
4 changes: 2 additions & 2 deletions doc/code/memory/3_memory_data_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../
- **`labels`**: Dictionary of labels for categorization and filtering
- **`prompt_metadata`**: Component-specific metadata (e.g., blob URIs, document types)
- **`converter_identifiers`**: List of converters applied to transform the prompt
- **`prompt_target_identifier`**: Information about the target that received this prompt
- **`attack_identifier`**: Information about the attack that generated this prompt
- **`scorer_identifier`**: Information about the scorer that evaluated this prompt
- **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`)
- **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`)
Expand Down Expand Up @@ -54,6 +52,8 @@ This rich context allows PyRIT to track the full lifecycle of each interaction,

A conversation is a list of `Messages` that share the same `conversation_id`. The sequence of the `MessagePieces` and their corresponding `Messages` dictates the order of the conversation.

A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory._get_conversation(conversation_id=...)` to retrieve it.

Here is a sample conversation made up of three `Messages` which all share the same conversation ID. The first `Message` is the `system` message, followed by a multi-modal `user` prompt with a text `MessagePiece` and an image `MessagePiece`, and finally the `assistant` response in the form of a text `MessagePiece`.

```{mermaid}
Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/6_azure_sql_memory.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@
"memory.add_message_to_memory(request=Message([message_list[1]]))\n",
"memory.add_message_to_memory(request=Message([message_list[2]]))\n",
"\n",
"entries = memory.get_conversation(conversation_id=conversation_id)\n",
"entries = memory.get_conversation_messages(conversation_id=conversation_id)\n",
"\n",
"for entry in entries:\n",
" print(entry)"
Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/6_azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
memory.add_message_to_memory(request=Message([message_list[1]]))
memory.add_message_to_memory(request=Message([message_list[2]]))

entries = memory.get_conversation(conversation_id=conversation_id)
entries = memory.get_conversation_messages(conversation_id=conversation_id)

for entry in entries:
print(entry)
7 changes: 5 additions & 2 deletions doc/code/memory/9_schema_diagram.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ flowchart LR
P_labels["labels (VARCHAR)"]
P_prompt_metadata["prompt_metadata (VARCHAR)"]
P_converter_identifiers["converter_identifiers (VARCHAR)"]
P_prompt_target_identifier["prompt_target_identifier (VARCHAR)"]
P_attack_identifier["attack_identifier (VARCHAR)"]
P_response_error["response_error (VARCHAR)"]
P_converted_value_data_type["converted_value_data_type (VARCHAR)"]
P_converted_value["converted_value (VARCHAR)"]
P_converted_value_sha256["converted_value_sha256 (VARCHAR)"]
P_original_prompt_id["original_prompt_id (UUID)"]
end
subgraph Conversations["Conversations"]
C_conversation_id["conversation_id (VARCHAR)"]
C_target_identifier["target_identifier (VARCHAR)"]
end
subgraph ScoreEntries["ScoreEntries"]
Sc_id["id (UUID)"]
Sc_prompt_request_response_id["prompt_request_response_id (VARCHAR)"]
Expand All @@ -63,6 +65,7 @@ flowchart LR
end
S_value_sha256 -- N:N relationship to query --> P_original_value_sha256
P_id -- 1:N relationship to query --> Sc_prompt_request_response_id
P_conversation_id -- N:1 relationship to query --> C_conversation_id

style S_value_sha256 fill:#ff8800ff
style P_id fill:#14a519ff
Expand Down
2 changes: 1 addition & 1 deletion doc/code/output/0_output.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@
"from pyrit.output import output_conversation_async\n",
"\n",
"# get the conversation from memory using the conversation id from the attack result\n",
"conversation = memory.get_conversation(conversation_id=attack_result.conversation_id)\n",
"conversation = memory.get_conversation_messages(conversation_id=attack_result.conversation_id)\n",
"\n",
"# print the conversation using the print conversation helper\n",
"await output_conversation_async(messages=conversation) # type: ignore"
Expand Down
2 changes: 1 addition & 1 deletion doc/code/output/0_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
from pyrit.output import output_conversation_async

# get the conversation from memory using the conversation id from the attack result
conversation = memory.get_conversation(conversation_id=attack_result.conversation_id)
conversation = memory.get_conversation_messages(conversation_id=attack_result.conversation_id)

# print the conversation using the print conversation helper
await output_conversation_async(messages=conversation) # type: ignore
Expand Down
10 changes: 8 additions & 2 deletions pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,10 @@ def request_piece_to_pyrit_message_piece(
Returns:
PyritMessagePiece domain object.
"""
if labels is not None:
# Only a truthy value counts as "passed"; an empty/falsy ``labels`` (e.g. {}
# forwarded on the happy path) is treated as not supplied to avoid a spurious
# warning. Matches MessagePiece's deprecated-kwarg guard.
if labels:
print_deprecation_message(
old_item="request_piece_to_pyrit_message_piece(..., labels=...)",
new_item="request_piece_to_pyrit_message_piece(...)",
Expand Down Expand Up @@ -582,7 +585,10 @@ def request_to_pyrit_message(
Returns:
PyritMessage ready to send to the target.
"""
if labels is not None:
# Only a truthy value counts as "passed"; an empty/falsy ``labels`` (e.g. {}
# forwarded on the happy path) is treated as not supplied to avoid a spurious
# warning. Matches MessagePiece's deprecated-kwarg guard.
if labels:
print_deprecation_message(
old_item="request_to_pyrit_message(..., labels=...)",
new_item="request_to_pyrit_message(...)",
Expand Down
28 changes: 26 additions & 2 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
AttackOutcome,
AttackResult,
ComponentIdentifier,
Conversation,
ConversationStats,
ConversationType,
MessagePiece,
Expand Down Expand Up @@ -266,7 +267,7 @@ async def get_conversation_messages_async(
raise ValueError(f"Conversation '{conversation_id}' is not part of attack '{attack_result_id}'")

# Get messages for this conversation
pyrit_messages = self._memory.get_conversation(conversation_id=conversation_id)
pyrit_messages = self._memory.get_conversation_messages(conversation_id=conversation_id)
backend_messages = await pyrit_messages_to_dto_async(list(pyrit_messages), memory=self._memory)

return ConversationMessagesResponse(
Expand Down Expand Up @@ -312,6 +313,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
cutoff_index=request.cutoff_index,
labels_override=labels,
remap_assistant_to_simulated=True,
target_identifier=target_identifier,
)
else:
conversation_id = str(uuid.uuid4())
Expand Down Expand Up @@ -344,6 +346,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
conversation_id=conversation_id,
prepended=request.prepended_conversation,
labels=labels, # deprecated
target_identifier=target_identifier,
)

return CreateAttackResponse(
Expand Down Expand Up @@ -475,9 +478,11 @@ async def create_related_conversation_async(

# --- Branch via duplication (preferred for tracking) ---------------
if request.source_conversation_id is not None and request.cutoff_index is not None:
source_metadata = self._memory._get_conversation(conversation_id=request.source_conversation_id)
new_conversation_id = self._duplicate_conversation_up_to(
source_conversation_id=request.source_conversation_id,
cutoff_index=request.cutoff_index,
target_identifier=source_metadata.target_identifier if source_metadata else None,
)
else:
new_conversation_id = str(uuid.uuid4())
Expand Down Expand Up @@ -622,11 +627,13 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR
labels=attack_labels, # deprecated
)
else:
existing_metadata = self._memory._get_conversation(conversation_id=msg_conversation_id)
await self._store_message_only_async(
conversation_id=msg_conversation_id,
request=request,
sequence=sequence,
labels=attack_labels, # deprecated
target_identifier=existing_metadata.target_identifier if existing_metadata else None,
)

await self._update_attack_after_message_async(attack_result_id=attack_result_id, ar=ar, request=request)
Expand Down Expand Up @@ -828,6 +835,7 @@ def _duplicate_conversation_up_to(
cutoff_index: int,
labels_override: dict[str, str] | None = None,
remap_assistant_to_simulated: bool = False,
target_identifier: ComponentIdentifier | None = None,
) -> str:
"""
Duplicate messages from a conversation up to and including a turn index.
Expand All @@ -846,10 +854,13 @@ def _duplicate_conversation_up_to(
``assistant`` are changed to ``simulated_assistant`` so the
branched context is inert and won't confuse the target.

target_identifier (ComponentIdentifier | None): The target the new conversation
is held with, if known. Recorded once for the duplicated conversation.

Returns:
The new conversation ID containing the duplicated messages.
"""
messages = self._memory.get_conversation(conversation_id=source_conversation_id)
messages = self._memory.get_conversation_messages(conversation_id=source_conversation_id)
messages_to_copy = [m for m in messages if m.sequence <= cutoff_index]

new_conversation_id, all_pieces = self._memory.duplicate_messages(messages=messages_to_copy)
Expand All @@ -865,6 +876,9 @@ def _duplicate_conversation_up_to(
piece.role = "simulated_assistant"

if all_pieces:
self._memory.add_conversation_to_memory(
conversation=Conversation(conversation_id=new_conversation_id, target_identifier=target_identifier)
)
self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces))

return new_conversation_id
Expand Down Expand Up @@ -953,8 +967,14 @@ async def _store_prepended_messages_async(
conversation_id: str,
prepended: list[Any],
labels: dict[str, str] | None = None, # deprecated
target_identifier: ComponentIdentifier | None = None,
) -> None:
"""Store prepended conversation messages in memory."""
if not prepended:
return
self._memory.add_conversation_to_memory(
conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier)
)
for seq, msg in enumerate(prepended):
for p in msg.pieces:
piece = request_piece_to_pyrit_message_piece(
Expand Down Expand Up @@ -1010,9 +1030,13 @@ async def _store_message_only_async(
request: AddMessageRequest,
sequence: int,
labels: dict[str, str] | None = None, # deprecated
target_identifier: ComponentIdentifier | None = None,
) -> None:
"""Store message without sending (send=False)."""
await self._persist_base64_pieces_async(request)
self._memory.add_conversation_to_memory(
conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier)
)
for p in request.pieces:
piece = request_piece_to_pyrit_message_piece(
piece=p,
Expand Down
Loading
Loading