diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 214e814d3e..54c739bbda 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -203,6 +203,7 @@ add_trace_processor, agent_span, custom_span, + flush_traces, function_span, gen_span_id, gen_trace_id, @@ -451,6 +452,7 @@ def enable_verbose_stdout_logging(): "add_trace_processor", "agent_span", "custom_span", + "flush_traces", "function_span", "generation_span", "get_current_span", diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 9f5e4f7568..4f4ee4a88e 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -41,6 +41,7 @@ __all__ = [ "add_trace_processor", "agent_span", + "flush_traces", "custom_span", "function_span", "generation_span", @@ -108,3 +109,15 @@ def set_tracing_export_api_key(api_key: str) -> None: Set the OpenAI API key for the backend exporter. """ default_exporter().set_api_key(api_key) + + +def flush_traces() -> None: + """Force an immediate flush of all buffered traces and spans. + + Call this at the end of each task in long-running worker processes + (Celery, FastAPI background tasks, RQ, etc.) to ensure traces are + exported to the backend rather than remaining buffered indefinitely. + """ + provider = get_trace_provider() + if hasattr(provider, "force_flush"): + provider.force_flush() diff --git a/src/agents/tracing/provider.py b/src/agents/tracing/provider.py index 90ea85cbf0..835e78a978 100644 --- a/src/agents/tracing/provider.py +++ b/src/agents/tracing/provider.py @@ -188,9 +188,22 @@ def create_span( ) -> Span[TSpanData]: """Create a new span.""" - @abstractmethod + def force_flush(self) -> None: + """Force all registered processors to flush their buffers immediately. + + The default implementation is a no-op so that existing + ``TraceProvider`` subclasses continue to work without modification. + Override this in your provider if you need custom flush behaviour. + """ + pass + def shutdown(self) -> None: - """Clean up any resources used by the provider.""" + """Clean up any resources used by the provider. + + The default implementation is a no-op for the same backward- + compatibility reasons as :meth:`force_flush`. + """ + pass class DefaultTraceProvider(TraceProvider): @@ -365,6 +378,13 @@ def create_span( trace_metadata=trace_metadata, ) + def force_flush(self) -> None: + """Force all processors to flush their buffers immediately.""" + self._refresh_disabled_flag() + if self._disabled: + return + self._multi_processor.force_flush() + def shutdown(self) -> None: if self._disabled: return diff --git a/tests/test_trace_processor.py b/tests/test_trace_processor.py index ad061d7995..d937c03b95 100644 --- a/tests/test_trace_processor.py +++ b/tests/test_trace_processor.py @@ -835,3 +835,103 @@ def test_truncate_string_for_json_limit_handles_escape_heavy_input(): assert truncated.endswith(exporter._OPENAI_TRACING_STRING_TRUNCATION_SUFFIX) assert exporter._value_json_size_bytes(truncated) <= max_bytes exporter.close() + + +def test_flush_traces_calls_provider_force_flush(): + """Test that flush_traces() delegates to the global trace provider's force_flush().""" + from unittest.mock import MagicMock, patch + + mock_provider = MagicMock() + + with patch("agents.tracing.get_trace_provider", return_value=mock_provider): + from agents.tracing import flush_traces + + flush_traces() + + mock_provider.force_flush.assert_called_once() + + +def test_flush_traces_importable_from_agents(): + """Test that flush_traces is importable from the top-level agents package.""" + from agents import flush_traces + + assert callable(flush_traces) + + +def test_flush_traces_tolerates_provider_without_override(): + """Test that flush_traces() is safe with a TraceProvider that does not override force_flush.""" + from unittest.mock import patch + + from agents.tracing import flush_traces + from agents.tracing.provider import TraceProvider + + class MinimalProvider(TraceProvider): + """A provider that only implements the required abstract methods.""" + + def register_processor(self, processor): + pass + + def set_processors(self, processors): + pass + + def get_current_trace(self): + return None + + def get_current_span(self): + return None + + def set_disabled(self, disabled): + pass + + def time_iso(self): + return "" + + def gen_trace_id(self): + return "t" + + def gen_span_id(self): + return "s" + + def gen_group_id(self): + return "g" + + def create_trace( # type: ignore[override] + self, + name: str, + trace_id: str | None = None, + group_id: str | None = None, + metadata: dict[str, Any] | None = None, + disabled: bool = False, + tracing: Any = None, + ) -> Any: + raise NotImplementedError + + def create_span( # type: ignore[override] + self, + span_data: Any, + span_id: str | None = None, + parent: Any = None, + disabled: bool = False, + ) -> Any: + raise NotImplementedError + + provider = MinimalProvider() + with patch("agents.tracing.get_trace_provider", return_value=provider): + # Should not raise - force_flush has a default no-op implementation + flush_traces() + + +def test_force_flush_respects_disabled_flag(): + """Test that force_flush() skips processing when tracing is disabled.""" + from unittest.mock import MagicMock + + from agents.tracing.provider import DefaultTraceProvider + + provider = DefaultTraceProvider() + mock_processor = MagicMock() + provider.register_processor(mock_processor) + + provider.set_disabled(True) + provider.force_flush() + + mock_processor.force_flush.assert_not_called()