Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 59 additions & 47 deletions sentry_sdk/integrations/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.consts import SPANDATA
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
)

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -40,32 +43,26 @@


COLLECTED_CHAT_PARAMS = {
"model": SPANDATA.AI_MODEL_ID,
"k": SPANDATA.AI_TOP_K,
"p": SPANDATA.AI_TOP_P,
"seed": SPANDATA.AI_SEED,
"frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
"presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
"raw_prompting": SPANDATA.AI_RAW_PROMPTING,
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
"k": SPANDATA.GEN_AI_REQUEST_TOP_K,
"p": SPANDATA.GEN_AI_REQUEST_TOP_P,
"seed": SPANDATA.GEN_AI_REQUEST_SEED,
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
}

COLLECTED_PII_CHAT_PARAMS = {
"tools": SPANDATA.AI_TOOLS,
"preamble": SPANDATA.AI_PREAMBLE,
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
"preamble": SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
}

COLLECTED_CHAT_RESP_ATTRS = {
"generation_id": SPANDATA.AI_GENERATION_ID,
"is_search_required": SPANDATA.AI_SEARCH_REQUIRED,
"finish_reason": SPANDATA.AI_FINISH_REASON,
"generation_id": SPANDATA.GEN_AI_RESPONSE_ID,
"finish_reason": SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
}

COLLECTED_PII_CHAT_RESP_ATTRS = {
"citations": SPANDATA.AI_CITATIONS,
"documents": SPANDATA.AI_DOCUMENTS,
"search_queries": SPANDATA.AI_SEARCH_QUERIES,
"search_results": SPANDATA.AI_SEARCH_RESULTS,
"tool_calls": SPANDATA.AI_TOOL_CALLS,
"tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
}


Expand Down Expand Up @@ -102,16 +99,16 @@ def collect_chat_response_fields(
if hasattr(res, "text"):
set_data_normalized(
span,
SPANDATA.AI_RESPONSES,
SPANDATA.GEN_AI_RESPONSE_TEXT,
[res.text],
)
for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
if hasattr(res, pii_attr):
set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))
for attr, spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS.items():
if hasattr(res, attr):
set_data_normalized(span, spandata_key, getattr(res, attr))

for attr in COLLECTED_CHAT_RESP_ATTRS:
for attr, spandata_key in COLLECTED_CHAT_RESP_ATTRS.items():
if hasattr(res, attr):
set_data_normalized(span, "ai." + attr, getattr(res, attr))
set_data_normalized(span, spandata_key, getattr(res, attr))

if hasattr(res, "meta"):
if hasattr(res.meta, "billed_units"):
Expand All @@ -127,9 +124,6 @@ def collect_chat_response_fields(
output_tokens=res.meta.tokens.output_tokens,
)

if hasattr(res.meta, "warnings"):
set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings)

@wraps(f)
def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
Expand All @@ -142,10 +136,11 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
return f(*args, **kwargs)

message = kwargs.get("message")
model = kwargs.get("model", "")

span = sentry_sdk.start_span(
op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
name="cohere.client.Chat",
op=OP.GEN_AI_CHAT,
name=f"chat {model}".strip(),
origin=CohereIntegration.origin,
)
span.__enter__()
Expand All @@ -159,20 +154,28 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
reraise(*exc_info)

with capture_internal_exceptions():
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")

if should_send_default_pii() and integration.include_prompts:
messages = []
for x in kwargs.get("chat_history", []):
role = getattr(x, "role", "").lower()
if role == "chatbot":
role = "assistant"
messages.append(
{
"role": role,
"content": getattr(x, "message", ""),
}
)
messages.append({"role": "user", "content": message})
messages = normalize_message_roles(messages)
set_data_normalized(
span,
SPANDATA.AI_INPUT_MESSAGES,
list(
map(
lambda x: {
"role": getattr(x, "role", "").lower(),
"content": getattr(x, "message", ""),
},
kwargs.get("chat_history", []),
)
)
+ [{"role": "user", "content": message}],
SPANDATA.GEN_AI_REQUEST_MESSAGES,
messages,
unpack=False,
)
for k, v in COLLECTED_PII_CHAT_PARAMS.items():
if k in kwargs:
Expand All @@ -181,7 +184,7 @@ def new_chat(*args: "Any", **kwargs: "Any") -> "Any":
for k, v in COLLECTED_CHAT_PARAMS.items():
if k in kwargs:
set_data_normalized(span, v, kwargs[k])
set_data_normalized(span, SPANDATA.AI_STREAMING, False)
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, False)

if streaming:
old_iterator = res
Expand Down Expand Up @@ -226,27 +229,36 @@ def new_embed(*args: "Any", **kwargs: "Any") -> "Any":
if integration is None:
return f(*args, **kwargs)

model = kwargs.get("model", "")

with sentry_sdk.start_span(
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
name="Cohere Embedding Creation",
op=OP.GEN_AI_EMBEDDINGS,
name=f"embeddings {model}".strip(),
origin=CohereIntegration.origin,
) as span:
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")

if "texts" in kwargs and (
should_send_default_pii() and integration.include_prompts
):
if isinstance(kwargs["texts"], str):
set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]])
set_data_normalized(
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, [kwargs["texts"]]
)
elif (
isinstance(kwargs["texts"], list)
and len(kwargs["texts"]) > 0
and isinstance(kwargs["texts"][0], str)
):
set_data_normalized(
span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"]
span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, kwargs["texts"]
)

if "model" in kwargs:
set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"]
)
try:
res = f(*args, **kwargs)
except Exception as e:
Expand Down
40 changes: 23 additions & 17 deletions tests/integrations/cohere/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,24 @@ def test_nonstreaming_chat(
tx = events[0]
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "ai.chat_completions.create.cohere"
assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model"
assert span["op"] == "gen_ai.chat"
assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere"
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"

if send_default_pii and include_prompts:
assert (
'{"role": "system", "content": "some context"}'
in span["data"][SPANDATA.AI_INPUT_MESSAGES]
in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
)
assert (
'{"role": "user", "content": "hello"}'
in span["data"][SPANDATA.AI_INPUT_MESSAGES]
in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
)
assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
else:
assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
assert SPANDATA.AI_RESPONSES not in span["data"]
assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"]["gen_ai.usage.output_tokens"] == 10
assert span["data"]["gen_ai.usage.input_tokens"] == 20
Expand Down Expand Up @@ -130,22 +132,24 @@ def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_p
tx = events[0]
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "ai.chat_completions.create.cohere"
assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model"
assert span["op"] == "gen_ai.chat"
assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere"
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"

if send_default_pii and include_prompts:
assert (
'{"role": "system", "content": "some context"}'
in span["data"][SPANDATA.AI_INPUT_MESSAGES]
in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
)
assert (
'{"role": "user", "content": "hello"}'
in span["data"][SPANDATA.AI_INPUT_MESSAGES]
in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
)
assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
else:
assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
assert SPANDATA.AI_RESPONSES not in span["data"]
assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

assert span["data"]["gen_ai.usage.output_tokens"] == 10
assert span["data"]["gen_ai.usage.input_tokens"] == 20
Expand Down Expand Up @@ -224,11 +228,13 @@ def test_embed(sentry_init, capture_events, send_default_pii, include_prompts):
tx = events[0]
assert tx["type"] == "transaction"
span = tx["spans"][0]
assert span["op"] == "ai.embeddings.create.cohere"
assert span["op"] == "gen_ai.embeddings"
assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere"
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
if send_default_pii and include_prompts:
assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES]
assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
else:
assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]

assert span["data"]["gen_ai.usage.input_tokens"] == 10
assert span["data"]["gen_ai.usage.total_tokens"] == 10
Expand Down
Loading