Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import atexit
import base64
import concurrent.futures
Expand Down Expand Up @@ -2040,6 +2041,37 @@ def compute_metadata():
)


async def load_prompt_async(
project: str | None = None,
slug: str | None = None,
version: str | int | None = None,
project_id: str | None = None,
id: str | None = None,
defaults: Mapping[str, Any] | None = None,
no_trace: bool = False,
environment: str | None = None,
app_url: str | None = None,
api_key: str | None = None,
org_name: str | None = None,
) -> "Prompt":
"""Asynchronously loads a prompt from the specified project."""
prompt = load_prompt(
project=project,
slug=slug,
version=version,
project_id=project_id,
id=id,
defaults=defaults,
no_trace=no_trace,
environment=environment,
app_url=app_url,
api_key=api_key,
org_name=org_name,
)
await asyncio.to_thread(lambda: prompt.name)
return prompt


def _is_parameters_ref(value: Any) -> bool:
return isinstance(value, dict) and isinstance(value.get("id"), str)

Expand Down
68 changes: 68 additions & 0 deletions py/src/braintrust/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import threading
import time
from collections.abc import AsyncGenerator
from unittest import TestCase
Expand Down Expand Up @@ -318,6 +319,73 @@ def test_submit_logs_request_413_skips_retries(self) -> None:
)


def _prompt_response(slug: str):
return {
"objects": [
{
"id": f"prompt-{slug}",
"project_id": "project-123",
"name": "Saved prompt",
"slug": slug,
"_xact_id": "v1",
"description": None,
"tags": None,
"prompt_data": {
"prompt": {
"type": "chat",
"messages": [{"role": "user", "content": "Hello {{name}}"}],
},
"options": {"model": "gpt-5-mini"},
},
}
]
}


@pytest.mark.asyncio
async def test_load_prompt_async_eagerly_fetches_prompt(with_simulate_login):
mock_api_conn = MagicMock()
mock_api_conn.get_json.return_value = _prompt_response("saved-prompt")

with patch.object(logger._state, "api_conn", return_value=mock_api_conn):
prompt = await braintrust.load_prompt_async(
project="test-project",
slug="saved-prompt",
)

# Unlike load_prompt(), load_prompt_async() resolves the prompt metadata before returning.
mock_api_conn.get_json.assert_called_once_with(
"/v1/prompt",
{
"project_name": "test-project",
"slug": "saved-prompt",
},
)
assert prompt.slug == "saved-prompt"
assert prompt.build(name="Ada")["messages"][0]["content"] == "Hello Ada"


@pytest.mark.asyncio
async def test_load_prompt_async_loads_prompts_in_parallel(with_simulate_login):
mock_api_conn = MagicMock()
barrier = threading.Barrier(2, timeout=1)

def get_json(_path, args):
barrier.wait()
return _prompt_response(args["slug"])

mock_api_conn.get_json.side_effect = get_json

with patch.object(logger._state, "api_conn", return_value=mock_api_conn):
prompt1, prompt2 = await asyncio.gather(
braintrust.load_prompt_async(project="test-project", slug="prompt-1"),
braintrust.load_prompt_async(project="test-project", slug="prompt-2"),
)

assert [prompt1.slug, prompt2.slug] == ["prompt-1", "prompt-2"]
assert mock_api_conn.get_json.call_count == 2


class TestLogger(TestCase):
def test_load_prompt_prefers_version_over_environment_for_project_slug(self):
mock_api_conn = MagicMock()
Expand Down
Loading