Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 28 additions & 48 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.ai.monitoring import set_ai_pipeline_name
from sentry_sdk.ai.utils import (
GEN_AI_ALLOWED_MESSAGE_ROLES,
get_start_span_function,
Expand Down Expand Up @@ -241,38 +240,28 @@ def setup_once() -> None:
_patch_embeddings_provider(OllamaEmbeddings)


class WatchedSpan:
span: "Span" = None # type: ignore[assignment]
children: "List[WatchedSpan]" = []
is_pipeline: bool = False

def __init__(self, span: "Span") -> None:
self.span = span


class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
"""Callback handler that creates Sentry spans."""

def __init__(
self, max_span_map_size: "Optional[int]", include_prompts: bool
) -> None:
self.span_map: "OrderedDict[UUID, WatchedSpan]" = OrderedDict()
self.span_map: "OrderedDict[UUID, sentry_sdk.tracing.Span]" = OrderedDict()
self.max_span_map_size = max_span_map_size
self.include_prompts = include_prompts

def gc_span_map(self) -> None:
if self.max_span_map_size is not None:
while len(self.span_map) > self.max_span_map_size:
run_id, watched_span = self.span_map.popitem(last=False)
self._exit_span(watched_span, run_id)
run_id, span = self.span_map.popitem(last=False)
self._exit_span(span, run_id)

def _handle_error(self, run_id: "UUID", error: "Any") -> None:
with capture_internal_exceptions():
if not run_id or run_id not in self.span_map:
return

span_data = self.span_map[run_id]
span = span_data.span
span = self.span_map[run_id]

sentry_sdk.capture_exception(error, span.scope)

Expand All @@ -291,29 +280,27 @@ def _create_span(
run_id: "UUID",
parent_id: "Optional[Any]",
**kwargs: "Any",
) -> "WatchedSpan":
watched_span: "Optional[WatchedSpan]" = None
) -> "sentry_sdk.tracing.Span":
span = None
if parent_id:
parent_span: "Optional[WatchedSpan]" = self.span_map.get(parent_id)
parent_span: "Optional[sentry_sdk.tracing.Span]" = self.span_map.get(
parent_id
)
if parent_span:
watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
parent_span.children.append(watched_span)
span = parent_span.start_child(**kwargs)

if watched_span is None:
watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
if span is None:
span = sentry_sdk.start_span(**kwargs)

watched_span.span.__enter__()
self.span_map[run_id] = watched_span
span.__enter__()
self.span_map[run_id] = span
self.gc_span_map()
return watched_span
return span

def _exit_span(
self: "SentryLangchainCallback", span_data: "WatchedSpan", run_id: "UUID"
self: "SentryLangchainCallback", span: "sentry_sdk.tracing.Span", run_id: "UUID"
) -> None:
if span_data.is_pipeline:
set_ai_pipeline_name(None)

span_data.span.__exit__(None, None, None)
span.__exit__(None, None, None)
del self.span_map[run_id]

def on_llm_start(
Expand Down Expand Up @@ -341,14 +328,13 @@ def on_llm_start(
or ""
)

watched_span = self._create_span(
span = self._create_span(
run_id,
parent_run_id,
op=OP.GEN_AI_TEXT_COMPLETION,
name=f"text_completion {model}".strip(),
origin=LangchainIntegration.origin,
)
span = watched_span.span

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "text_completion")

Expand Down Expand Up @@ -421,14 +407,13 @@ def on_chat_model_start(
or ""
)

watched_span = self._create_span(
span = self._create_span(
run_id,
kwargs.get("parent_run_id"),
op=OP.GEN_AI_CHAT,
name=f"chat {model}".strip(),
origin=LangchainIntegration.origin,
)
span = watched_span.span

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
if model:
Expand Down Expand Up @@ -505,8 +490,7 @@ def on_chat_model_end(
if not run_id or run_id not in self.span_map:
return

span_data = self.span_map[run_id]
span = span_data.span
span = self.span_map[run_id]

if should_send_default_pii() and self.include_prompts:
set_data_normalized(
Expand All @@ -516,7 +500,7 @@ def on_chat_model_end(
)

_record_token_usage(span, response)
self._exit_span(span_data, run_id)
self._exit_span(span, run_id)

def on_llm_end(
self: "SentryLangchainCallback",
Expand All @@ -530,8 +514,7 @@ def on_llm_end(
if not run_id or run_id not in self.span_map:
return

span_data = self.span_map[run_id]
span = span_data.span
span = self.span_map[run_id]

try:
generation = response.generations[0][0]
Expand Down Expand Up @@ -579,7 +562,7 @@ def on_llm_end(
)

_record_token_usage(span, response)
self._exit_span(span_data, run_id)
self._exit_span(span, run_id)
Comment thread
alexander-alderman-webb marked this conversation as resolved.

def on_llm_error(
self: "SentryLangchainCallback",
Expand Down Expand Up @@ -612,15 +595,14 @@ def on_agent_finish(
if not run_id or run_id not in self.span_map:
return

span_data = self.span_map[run_id]
span = span_data.span
span = self.span_map[run_id]

if should_send_default_pii() and self.include_prompts:
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, finish.return_values.items()
)

self._exit_span(span_data, run_id)
self._exit_span(span, run_id)

def on_tool_start(
self: "SentryLangchainCallback",
Expand All @@ -637,14 +619,13 @@ def on_tool_start(

tool_name = serialized.get("name") or kwargs.get("name") or ""

watched_span = self._create_span(
span = self._create_span(
run_id,
kwargs.get("parent_run_id"),
op=OP.GEN_AI_EXECUTE_TOOL,
name=f"execute_tool {tool_name}".strip(),
origin=LangchainIntegration.origin,
)
span = watched_span.span

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
Expand Down Expand Up @@ -681,13 +662,12 @@ def on_tool_end(
if not run_id or run_id not in self.span_map:
return

span_data = self.span_map[run_id]
span = span_data.span
span = self.span_map[run_id]

if should_send_default_pii() and self.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)

self._exit_span(span_data, run_id)
self._exit_span(span, run_id)

def on_tool_error(
self,
Expand Down
Loading