diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 10e8bc00..34aa9b74 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1,3 +1,4 @@ +import asyncio import atexit import base64 import builtins @@ -2047,6 +2048,37 @@ def compute_metadata(): ) +async def load_prompt_async( + project: str | None = None, + slug: str | None = None, + version: str | int | None = None, + project_id: str | None = None, + id: str | None = None, + defaults: Mapping[str, Any] | None = None, + no_trace: bool = False, + environment: str | None = None, + app_url: str | None = None, + api_key: str | None = None, + org_name: str | None = None, +) -> "Prompt": + """Asynchronously loads a prompt from the specified project.""" + prompt = load_prompt( + project=project, + slug=slug, + version=version, + project_id=project_id, + id=id, + defaults=defaults, + no_trace=no_trace, + environment=environment, + app_url=app_url, + api_key=api_key, + org_name=org_name, + ) + await asyncio.to_thread(lambda: prompt.name) + return prompt + + def _is_parameters_ref(value: Any) -> bool: return isinstance(value, dict) and isinstance(value.get("id"), str) diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 9c7cfe1b..422f8bc9 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -2,9 +2,11 @@ # pyright: reportPrivateUsage=false import asyncio import builtins +import inspect import json import logging import os +import threading import time from collections.abc import AsyncGenerator from unittest import TestCase @@ -319,6 +321,80 @@ def test_submit_logs_request_413_skips_retries(self) -> None: ) +def test_load_prompt_async_signature_matches_load_prompt(): + assert ( + inspect.signature(braintrust.load_prompt_async).parameters + == inspect.signature(braintrust.load_prompt).parameters + ) + + +def _prompt_response(slug: str): + return { + "objects": [ + { + "id": f"prompt-{slug}", + "project_id": "project-123", + "name": "Saved prompt", + "slug": slug, + "_xact_id": "v1", + "description": None, + "tags": None, + "prompt_data": { + "prompt": { + "type": "chat", + "messages": [{"role": "user", "content": "Hello {{name}}"}], + }, + "options": {"model": "gpt-5-mini"}, + }, + } + ] + } + + +@pytest.mark.asyncio +async def test_load_prompt_async_eagerly_fetches_prompt(with_simulate_login): + mock_api_conn = MagicMock() + mock_api_conn.get_json.return_value = _prompt_response("saved-prompt") + + with patch.object(logger._state, "api_conn", return_value=mock_api_conn): + prompt = await braintrust.load_prompt_async( + project="test-project", + slug="saved-prompt", + ) + + # Unlike load_prompt(), load_prompt_async() resolves the prompt metadata before returning. + mock_api_conn.get_json.assert_called_once_with( + "/v1/prompt", + { + "project_name": "test-project", + "slug": "saved-prompt", + }, + ) + assert prompt.slug == "saved-prompt" + assert prompt.build(name="Ada")["messages"][0]["content"] == "Hello Ada" + + +@pytest.mark.asyncio +async def test_load_prompt_async_loads_prompts_in_parallel(with_simulate_login): + mock_api_conn = MagicMock() + barrier = threading.Barrier(2, timeout=1) + + def get_json(_path, args): + barrier.wait() + return _prompt_response(args["slug"]) + + mock_api_conn.get_json.side_effect = get_json + + with patch.object(logger._state, "api_conn", return_value=mock_api_conn): + prompt1, prompt2 = await asyncio.gather( + braintrust.load_prompt_async(project="test-project", slug="prompt-1"), + braintrust.load_prompt_async(project="test-project", slug="prompt-2"), + ) + + assert [prompt1.slug, prompt2.slug] == ["prompt-1", "prompt-2"] + assert mock_api_conn.get_json.call_count == 2 + + class TestLogger(TestCase): def test_load_prompt_prefers_version_over_environment_for_project_slug(self): mock_api_conn = MagicMock()