From 6db5eaca2e1d13e146cc06ce0d60e0fb6694f18f Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Mon, 29 Jun 2026 11:32:01 -0400 Subject: [PATCH] feat(prompts): add async prompt loading Add an awaitable prompt-loading API that mirrors `load_prompt()` while resolving prompt metadata off the event loop: ```python prompt = await braintrust.load_prompt_async(project="My Project", slug="my-prompt") kwargs = prompt.build(name="Ada") ``` The API can be used with `asyncio.gather()` to load multiple prompts without blocking the event loop: ```python prompt_a, prompt_b = await asyncio.gather( braintrust.load_prompt_async(project="My Project", slug="prompt-a"), braintrust.load_prompt_async(project="My Project", slug="prompt-b"), ) ``` The async API reuses the existing sync `load_prompt()` implementation so cache and fallback behavior stay consistent, and uses `asyncio.to_thread()` to avoid blocking callers' event loops during the synchronous HTTP/cache lookup. Tests cover eager metadata loading and parallel `load_prompt_async()` usage. --- py/src/braintrust/logger.py | 32 +++++++++++++++ py/src/braintrust/test_logger.py | 68 ++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index c6f46802..9d07d007 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1,3 +1,4 @@ +import asyncio import atexit import base64 import concurrent.futures @@ -2040,6 +2041,37 @@ def compute_metadata(): ) +async def load_prompt_async( + project: str | None = None, + slug: str | None = None, + version: str | int | None = None, + project_id: str | None = None, + id: str | None = None, + defaults: Mapping[str, Any] | None = None, + no_trace: bool = False, + environment: str | None = None, + app_url: str | None = None, + api_key: str | None = None, + org_name: str | None = None, +) -> "Prompt": + """Asynchronously loads a prompt from the specified project.""" + prompt = load_prompt( + project=project, + slug=slug, + version=version, + project_id=project_id, + id=id, + defaults=defaults, + no_trace=no_trace, + environment=environment, + app_url=app_url, + api_key=api_key, + org_name=org_name, + ) + await asyncio.to_thread(lambda: prompt.name) + return prompt + + def _is_parameters_ref(value: Any) -> bool: return isinstance(value, dict) and isinstance(value.get("id"), str) diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 1f3149e5..b6aba469 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -4,6 +4,7 @@ import json import logging import os +import threading import time from collections.abc import AsyncGenerator from unittest import TestCase @@ -318,6 +319,73 @@ def test_submit_logs_request_413_skips_retries(self) -> None: ) +def _prompt_response(slug: str): + return { + "objects": [ + { + "id": f"prompt-{slug}", + "project_id": "project-123", + "name": "Saved prompt", + "slug": slug, + "_xact_id": "v1", + "description": None, + "tags": None, + "prompt_data": { + "prompt": { + "type": "chat", + "messages": [{"role": "user", "content": "Hello {{name}}"}], + }, + "options": {"model": "gpt-5-mini"}, + }, + } + ] + } + + +@pytest.mark.asyncio +async def test_load_prompt_async_eagerly_fetches_prompt(with_simulate_login): + mock_api_conn = MagicMock() + mock_api_conn.get_json.return_value = _prompt_response("saved-prompt") + + with patch.object(logger._state, "api_conn", return_value=mock_api_conn): + prompt = await braintrust.load_prompt_async( + project="test-project", + slug="saved-prompt", + ) + + # Unlike load_prompt(), load_prompt_async() resolves the prompt metadata before returning. + mock_api_conn.get_json.assert_called_once_with( + "/v1/prompt", + { + "project_name": "test-project", + "slug": "saved-prompt", + }, + ) + assert prompt.slug == "saved-prompt" + assert prompt.build(name="Ada")["messages"][0]["content"] == "Hello Ada" + + +@pytest.mark.asyncio +async def test_load_prompt_async_loads_prompts_in_parallel(with_simulate_login): + mock_api_conn = MagicMock() + barrier = threading.Barrier(2, timeout=1) + + def get_json(_path, args): + barrier.wait() + return _prompt_response(args["slug"]) + + mock_api_conn.get_json.side_effect = get_json + + with patch.object(logger._state, "api_conn", return_value=mock_api_conn): + prompt1, prompt2 = await asyncio.gather( + braintrust.load_prompt_async(project="test-project", slug="prompt-1"), + braintrust.load_prompt_async(project="test-project", slug="prompt-2"), + ) + + assert [prompt1.slug, prompt2.slug] == ["prompt-1", "prompt-2"] + assert mock_api_conn.get_json.call_count == 2 + + class TestLogger(TestCase): def test_load_prompt_prefers_version_over_environment_for_project_slug(self): mock_api_conn = MagicMock()