Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def get_llm_backed_user_simulator_prompt(
"""Formats the prompt for the llm-backed user simulator"""
from jinja2 import DictLoader
from jinja2 import pass_context
from jinja2 import Template
from jinja2.sandbox import SandboxedEnvironment

templates = {
Expand All @@ -200,7 +199,7 @@ def get_llm_backed_user_simulator_prompt(
def _render_string_filter(context, template_string):
if not template_string:
return ""
return Template(template_string).render(context)
return template_env.from_string(template_string).render(context.get_all())

template_env.filters["render_string_filter"] = _render_string_filter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ def get_per_turn_user_simulator_quality_prompt(
):
"""Formats the prompt for the per turn user simulator evaluator"""
from jinja2 import DictLoader
from jinja2 import Environment
from jinja2 import pass_context
from jinja2 import Template
from jinja2.sandbox import SandboxedEnvironment

templates = {
"verifier_instructions": (
Expand All @@ -232,13 +231,13 @@ def get_per_turn_user_simulator_quality_prompt(
)
),
}
template_env = Environment(loader=DictLoader(templates))
template_env = SandboxedEnvironment(loader=DictLoader(templates))

@pass_context
def _render_string_filter(context, template_string):
if not template_string:
return ""
return Template(template_string).render(context)
return template_env.from_string(template_string).render(context.get_all())

template_env.filters["render_string_filter"] = _render_string_filter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.adk.evaluation.simulation.llm_backed_user_simulator_prompts import is_valid_user_simulator_template
from google.adk.evaluation.simulation.user_simulator_personas import UserBehavior
from google.adk.evaluation.simulation.user_simulator_personas import UserPersona
from jinja2.exceptions import SecurityError
import pytest

_MOCK_DEFAULT_TEMPLATE = textwrap.dedent("""\
Expand Down Expand Up @@ -208,6 +209,57 @@ def test_get_llm_backed_user_simulator_prompt_with_persona(self, mocker):
test stop""").strip()
assert prompt == expected_prompt

def test_get_llm_backed_user_simulator_prompt_renders_persona_templates_in_sandbox(
self,
):
user_persona = UserPersona(
id="test_persona",
description="Test persona description",
behaviors=[
UserBehavior(
name="Behavior {{ stop_signal }}",
description="Description {{ stop_signal }}",
behavior_instructions=["instruction {{ stop_signal }}"],
violation_rubrics=["rubric 1"],
)
],
)

prompt = get_llm_backed_user_simulator_prompt(
conversation_plan="test plan",
conversation_history="test history",
stop_signal="test stop",
user_persona=user_persona,
)

assert "## Behavior test stop" in prompt
assert "Description test stop" in prompt
assert " * instruction test stop" in prompt

def test_get_llm_backed_user_simulator_prompt_blocks_unsafe_persona_templates(
self,
):
user_persona = UserPersona(
id="test_persona",
description="Test persona description",
behaviors=[
UserBehavior(
name="{{ ''.__class__.__mro__ }}",
description="Test behavior description",
behavior_instructions=["instruction 1"],
violation_rubrics=["rubric 1"],
)
],
)

with pytest.raises(SecurityError):
get_llm_backed_user_simulator_prompt(
conversation_plan="test plan",
conversation_history="test history",
stop_signal="test stop",
user_persona=user_persona,
)


class TestIsValidUserSimulatorTemplate:
"""Test cases for is_valid_user_simulator_template."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from google.adk.evaluation.simulation.per_turn_user_simulator_quality_prompts import get_per_turn_user_simulator_quality_prompt
from google.adk.evaluation.simulation.user_simulator_personas import UserBehavior
from google.adk.evaluation.simulation.user_simulator_personas import UserPersona
from jinja2.exceptions import SecurityError
import pytest

_MOCK_DEFAULT_TEMPLATE = textwrap.dedent("""\
Default template
Expand Down Expand Up @@ -182,3 +184,56 @@ def test_get_per_turn_user_simulator_quality_prompt_with_persona(
# Stop signal
stop""").strip()
assert prompt == expected_prompt

def test_get_per_turn_user_simulator_quality_prompt_renders_persona_templates_in_sandbox(
self,
):
persona = UserPersona(
id="test_persona",
description="Test persona description.",
behaviors=[
UserBehavior(
name="criteria {{ stop_signal }}",
description="Test behavior {{ stop_signal }}.",
behavior_instructions=["instruction1"],
violation_rubrics=["violation {{ stop_signal }}"],
)
],
)

prompt = get_per_turn_user_simulator_quality_prompt(
conversation_plan="plan",
conversation_history="history",
generated_user_response="response",
stop_signal="stop",
user_persona=persona,
)

assert "## Criteria: criteria stop" in prompt
assert "Test behavior stop." in prompt
assert " * violation stop" in prompt

def test_get_per_turn_user_simulator_quality_prompt_blocks_unsafe_persona_templates(
self,
):
persona = UserPersona(
id="test_persona",
description="Test persona description.",
behaviors=[
UserBehavior(
name="{{ ''.__class__.__mro__ }}",
description="Test behavior description.",
behavior_instructions=["instruction1"],
violation_rubrics=["violation1"],
)
],
)

with pytest.raises(SecurityError):
get_per_turn_user_simulator_quality_prompt(
conversation_plan="plan",
conversation_history="history",
generated_user_response="response",
stop_signal="stop",
user_persona=persona,
)