Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/danswer/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
GEN_AI_CLIENT_SECRET = os.environ.get("GEN_AI_CLIENT_SECRET") or None
GEN_AI_ACCOUNT_ID = os.environ.get("GEN_AI_ACCOUNT_ID") or None
GEN_AI_TENANT_ID = os.environ.get("GEN_AI_TENANT_ID") or None
GEN_AI_VENDOR = os.environ.get("GEN_AI_VENDOR") or "openai"
GEN_AI_MODEL_NAME = os.environ.get("GEN_AI_MODEL_NAME") or "gpt-4o-2024-11-20"
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily
Expand Down
11 changes: 10 additions & 1 deletion backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.override_models import LLMOverride
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
Expand Down Expand Up @@ -572,7 +573,15 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm, _ = get_llms_for_persona(persona)
llm_override = None
if channel_config and channel_config.channel_config:
ch_vendor = channel_config.channel_config.get("llm_vendor")
ch_model = channel_config.channel_config.get("llm_model_name")
if ch_vendor or ch_model:
llm_override = LLMOverride(
model_provider=ch_vendor, model_version=ch_model
)
llm, _ = get_llms_for_persona(persona, llm_override=llm_override)

# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,9 @@ class ChannelConfig(TypedDict):
jira_config: NotRequired[dict[str, Any]] # Contains all JIRA related settings
# Curated response config if user asks for more help
curated_response_config: NotRequired[dict[str, Any]]
# LLM configuration for this channel
llm_vendor: NotRequired[str]
llm_model_name: NotRequired[str]


class SlackBotResponseType(str, PyEnum):
Expand Down
95 changes: 62 additions & 33 deletions backend/danswer/llm/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from danswer.configs.model_configs import GEN_AI_CLIENT_SECRET
from danswer.configs.model_configs import GEN_AI_IDENTITY_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_NAME
from danswer.configs.model_configs import GEN_AI_TENANT_ID
from danswer.configs.model_configs import GEN_AI_VENDOR
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
Expand Down Expand Up @@ -67,20 +69,23 @@ def __init__(
# Not used here but you probably want a model server that isn't completely open
api_key: str | None,
timeout: int,
endpoint: str
| None = "https://alpha.uipath.com/{account_id}/{tenant_id}/llmgateway_/api/raw/vendor/openai/model/gpt-4o-2024-11-20/completions",
endpoint: str | None = None,
identity_url: str | None = GEN_AI_IDENTITY_ENDPOINT,
client_id: str | None = GEN_AI_CLIENT_ID,
client_secret: str | None = GEN_AI_CLIENT_SECRET,
account_id: str | None = GEN_AI_ACCOUNT_ID,
tenant_id: str | None = GEN_AI_TENANT_ID,
max_output_tokens: int = int(GEN_AI_MAX_OUTPUT_TOKENS),
api_version: str | None = GEN_AI_API_VERSION,
llm_vendor: str | None = None,
llm_model_name: str | None = None,
):
vendor = llm_vendor or GEN_AI_VENDOR
model = llm_model_name or GEN_AI_MODEL_NAME
if not endpoint:
raise ValueError(
"Cannot point Danswer to a custom LLM server without providing the "
"endpoint for the model server."
endpoint = (
"https://alpha.uipath.com/{account_id}/{tenant_id}"
f"/llmgateway_/api/raw/vendor/{vendor}/model/{model}/completions"
)

if not identity_url:
Expand Down Expand Up @@ -129,52 +134,71 @@ def __init__(
self._max_output_tokens = max_output_tokens
self._timeout = timeout
self.token = self._get_token()
# TODO: Remove hard-coding
self._model_provider = "custom"
self._model_version = "gpt-4"
self._model_provider = vendor
self._model_version = model
self._temperature = 0.0
self._api_key = api_key

if max_output_tokens <= 0:
self._max_output_tokens = 7000

def _execute(self, input: LanguageModelInput) -> AIMessage:
is_bedrock = self._model_provider == "awsbedrock"
api_flavor = "converse" if is_bedrock else "chat-completions"

headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.token,
"X-UiPath-LlmGateway-RequestingProduct": "darwin",
"X-UiPath-LlmGateway-RequestingFeature": "ChatWithAssistant",
"X-UiPath-LlmGateway-ApiFlavor": "chat-completions",
"X-UiPath-LlmGateway-ApiFlavor": api_flavor,
"X-UiPath-LlmGateway-ApiVersion": "2024-10-21",
"X-UiPath-LlmGateway-TimeoutSeconds": "60",
"X-UIPATH-STREAMING-ENABLED": "false",
}

# print(f"Input: {input}")
chatPrompt = convert_lm_input_to_prompt(input)

json_array = []
messages = chatPrompt.to_messages()
for msg in messages:
mapped_type = self._map_type(msg.type)
json_obj = {
"role": mapped_type,
"content": self._clean_json_string(msg.content),
}
json_array.append(json_obj)

data = {"max_tokens": self._max_output_tokens, "messages": json_array}
if is_bedrock:
# AWS Bedrock Converse API format
bedrock_messages = []
for msg in messages:
mapped_type = self._map_type(msg.type)
if mapped_type == "system":
continue # system handled separately below
bedrock_messages.append(
{
"role": mapped_type,
"content": [{"text": self._clean_json_string(msg.content)}],
}
)
data: dict = {
"messages": bedrock_messages,
"inferenceConfig": {"maxTokens": self._max_output_tokens},
}
system_msgs = [
{"text": self._clean_json_string(msg.content)}
for msg in messages
if self._map_type(msg.type) == "system"
]
if system_msgs:
data["system"] = system_msgs
else:
# OpenAI chat-completions format
json_array = []
for msg in messages:
mapped_type = self._map_type(msg.type)
json_array.append(
{
"role": mapped_type,
"content": self._clean_json_string(msg.content),
}
)
data = {"max_tokens": self._max_output_tokens, "messages": json_array}

try:
print(data)
with open("requestdata.json", "w") as fp:
json.dump(data, fp)

# json_str = json.dumps(data, ensure_ascii=False, indent=4)
# print(f"Request Data: {json_str}")
# json_data = json.loads(json_str)
response = requests.post(
# self._endpoint, headers=headers, data=json_str, timeout=self._timeout
self._endpoint,
headers=headers,
json=data,
Expand All @@ -185,16 +209,21 @@ def _execute(self, input: LanguageModelInput) -> AIMessage:

response.raise_for_status()
try:
data = json.loads(response.content)
print(data)
response_data = json.loads(response.content)
except json.decoder.JSONDecodeError as e:
print("Failed to parse JSON:", response.content)
raise e

message_content = "No response from LLM server"
if data["choices"]:
message_content = data["choices"][0]["message"]["content"]
# print(message_content)
if is_bedrock:
output = (
response_data.get("output", {}).get("message", {}).get("content", [])
)
if output:
message_content = output[0].get("text", message_content)
else:
if response_data.get("choices"):
message_content = response_data["choices"][0]["message"]["content"]
return AIMessage(content=message_content)

def _clean_json_string(self, input_string):
Expand Down
85 changes: 15 additions & 70 deletions backend/danswer/llm/factory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_NAME
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import fetch_provider
from danswer.configs.model_configs import GEN_AI_VENDOR
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.llm.custom_llm import CustomModelServer


def get_main_llm_from_tuple(
Expand All @@ -26,40 +23,13 @@ def get_llms_for_persona(
) -> tuple[LLM, LLM]:
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None

provider_name = model_provider_override or persona.llm_model_provider_override
if not provider_name:
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
)

with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)

if not llm_provider:
raise ValueError("No LLM provider found")

model = model_version_override or persona.llm_model_version_override
fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
if not model:
raise ValueError("No model name found")
if not fast_model:
raise ValueError("No fast model name found")
llm_override.temperature if llm_override else None

def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
additional_headers=additional_headers,
)
vendor = model_provider_override or GEN_AI_VENDOR
model = model_version_override or GEN_AI_MODEL_NAME

return _create_llm(model), _create_llm(fast_model)
llm = get_llm(provider=vendor, model=model)
return llm, llm


def get_default_llms(
Expand All @@ -70,35 +40,8 @@ def get_default_llms(
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()

with get_session_context_manager() as db_session:
llm_provider = fetch_default_provider(db_session)

if not llm_provider:
raise ValueError("No default LLM provider found")

model_name = llm_provider.default_model_name
fast_model_name = (
llm_provider.fast_default_model_name or llm_provider.default_model_name
)
if not model_name:
raise ValueError("No default model name found")
if not fast_model_name:
raise ValueError("No fast default model name found")

def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)

return _create_llm(model_name), _create_llm(fast_model_name)
llm = get_llm(provider=GEN_AI_VENDOR, model=GEN_AI_MODEL_NAME, timeout=timeout)
return llm, llm


def get_llm(
Expand All @@ -113,6 +56,8 @@ def get_llm(
additional_headers: dict[str, str] | None = None,
) -> LLM:
return CustomModelServer(
timeout=timeout,
api_key=api_key,
)
timeout=timeout,
api_key=api_key,
llm_vendor=provider,
llm_model_name=model,
)
2 changes: 2 additions & 0 deletions backend/danswer/server/manage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class SlackBotConfigCreationRequest(BaseModel):
jira_title_filter: list[str] | None = None
curated_response_user_title_filter: list[str] | None = None
response_type: SlackBotResponseType
llm_vendor: str | None = None
llm_model_name: str | None = None

@validator("answer_filters", pre=True)
def validate_filters(cls, value: list[str]) -> list[str]:
Expand Down
6 changes: 6 additions & 0 deletions backend/danswer/server/manage/slack_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def _form_channel_config(
] = curated_response_user_title_filter
if curated_response_config:
channel_config["curated_response_config"] = curated_response_config
if slack_bot_config_creation_request.llm_vendor:
channel_config["llm_vendor"] = slack_bot_config_creation_request.llm_vendor
if slack_bot_config_creation_request.llm_model_name:
channel_config[
"llm_model_name"
] = slack_bot_config_creation_request.llm_model_name

channel_config[
"respond_to_bots"
Expand Down
4 changes: 4 additions & 0 deletions deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ services:
- GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-}
- GEN_AI_ACCOUNT_ID=${GEN_AI_ACCOUNT_ID:-}
- GEN_AI_TENANT_ID=${GEN_AI_TENANT_ID:-}
- GEN_AI_VENDOR=${GEN_AI_VENDOR:-openai}
- GEN_AI_MODEL_NAME=${GEN_AI_MODEL_NAME:-gpt-4o-2024-11-20}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
Expand Down Expand Up @@ -133,6 +135,8 @@ services:
- GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-}
- GEN_AI_ACCOUNT_ID=${GEN_AI_ACCOUNT_ID:-}
- GEN_AI_TENANT_ID=${GEN_AI_TENANT_ID:-}
- GEN_AI_VENDOR=${GEN_AI_VENDOR:-openai}
- GEN_AI_MODEL_NAME=${GEN_AI_MODEL_NAME:-gpt-4o-2024-11-20}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
Expand Down
4 changes: 4 additions & 0 deletions deployment/docker_compose/docker-compose.local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ services:
- GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-}
- GEN_AI_ACCOUNT_ID=${GEN_AI_ACCOUNT_ID:-}
- GEN_AI_TENANT_ID=${GEN_AI_TENANT_ID:-}
- GEN_AI_VENDOR=${GEN_AI_VENDOR:-openai}
- GEN_AI_MODEL_NAME=${GEN_AI_MODEL_NAME:-gpt-4o-2024-11-20}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
Expand Down Expand Up @@ -136,6 +138,8 @@ services:
- GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-}
- GEN_AI_ACCOUNT_ID=${GEN_AI_ACCOUNT_ID:-}
- GEN_AI_TENANT_ID=${GEN_AI_TENANT_ID:-}
- GEN_AI_VENDOR=${GEN_AI_VENDOR:-openai}
- GEN_AI_MODEL_NAME=${GEN_AI_MODEL_NAME:-gpt-4o-2024-11-20}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
Expand Down
Loading
Loading