Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
from apps.db.db import exec_sql, get_version, check_connection
from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
from apps.system.crud.parameter_manage import get_groups
from apps.system.schemas.system_schema import AssistantOutDsSchema
Expand Down Expand Up @@ -176,11 +177,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
@classmethod
async def create(cls, *args, **kwargs):
specialized_model_id = None
_ai_model_list = []
if args[3]:
if args[1]:
ws_id = args[1].oid
_ai_model_list = get_ai_model_list_by_workspace(args[0], ws_id)
if args[3].enable_custom_model:
if args[3].custom_model:
specialized_model_id = args[3].custom_model
print("use custom model: id[" + args[3].custom_model + "]")
if any(str(model.id) == str(args[3].custom_model) for model in _ai_model_list):
specialized_model_id = args[3].custom_model
print("use custom model: id[" + specialized_model_id + "]")
config: LLMConfig = await get_default_config(specialized_model_id)
instance = cls(*args, **kwargs, config=config)

Expand Down
4 changes: 2 additions & 2 deletions backend/apps/system/api/aimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ async def update_model_ws_mapping_by_id(
return [str(ws_id) for ws_id in ws_ids]


@router.get("/list_by_ws", response_model=AiModelBrief, summary=f"{PLACEHOLDER_PREFIX}system_model_query",
@router.get("/list/by_ws", response_model=List[AiModelBrief], summary=f"{PLACEHOLDER_PREFIX}system_model_query",
description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def get_model_by_ws(
session: SessionDep,
current_user: CurrentUser
):
return get_ai_model_list_by_workspace(session, current_user.workspace_id)
return get_ai_model_list_by_workspace(session, current_user.oid)
68 changes: 42 additions & 26 deletions backend/apps/system/models/system_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

from typing import Optional

from pydantic import field_serializer
from sqlmodel import BigInteger, Field, Text, SQLModel

from common.core.models import SnowflakeBase
from common.core.schemas import BaseCreatorDTO

Expand All @@ -9,84 +11,98 @@ class AiModelBase:
supplier: int = Field(nullable=False)
name: str = Field(max_length=255, nullable=False)
model_type: int = Field(nullable=False)
base_model: str = Field(max_length = 255, nullable=False)
base_model: str = Field(max_length=255, nullable=False)
default_model: bool = Field(default=False, nullable=False)


class AiModelDetail(SnowflakeBase, AiModelBase, table=True):
__tablename__ = "ai_model"
api_key: str | None = Field(default=None, nullable=True, sa_type=Text())
api_domain: str = Field(nullable=False, sa_type=Text())
protocol: int = Field(nullable=False, default = 1)
config: str = Field(sa_type = Text())
status: int = Field(nullable=False, default = 1)
create_time: int = Field(default=0, sa_type=BigInteger())

__tablename__ = "ai_model"
api_key: str | None = Field(default=None, nullable=True, sa_type=Text())
api_domain: str = Field(nullable=False, sa_type=Text())
protocol: int = Field(nullable=False, default=1)
config: str = Field(sa_type=Text())
status: int = Field(nullable=False, default=1)
create_time: int = Field(default=0, sa_type=BigInteger())


class AiModelWorkspaceMapping(SnowflakeBase, table=True):
__tablename__ = "ai_model_workspace_mapping"
ai_model_id: int = Field(default=None, nullable=True, sa_type=BigInteger())
workspace_id: int = Field(default=None, nullable=True, sa_type=BigInteger())


class AiModelBrief(SQLModel):
id: int
name: str
default_model: bool
supplier: int

@field_serializer("id")
def id_to_str(self, v: int) -> str:
return str(v)


class WorkspaceBase(SQLModel):
name: str = Field(max_length=255, nullable=False)


class WorkspaceEditor(WorkspaceBase, BaseCreatorDTO):
pass



class WorkspaceModel(SnowflakeBase, WorkspaceBase, table=True):
__tablename__ = "sys_workspace"
create_time: int = Field(default=0, sa_type=BigInteger())



class UserWsBaseModel(SQLModel):
uid: int = Field(nullable=False, sa_type=BigInteger())
oid: int = Field(nullable=False, sa_type=BigInteger())
weight: int = Field(default=0, nullable=False)

weight: int = Field(default=0, nullable=False)


class UserWsModel(SnowflakeBase, UserWsBaseModel, table=True):
__tablename__ = "sys_user_ws"


class AssistantBaseModel(SQLModel):
name: str = Field(max_length=255, nullable=False)
type: int = Field(nullable=False, default=0)
domain: str = Field(max_length=255, nullable=False)
description: Optional[str] = Field(sa_type = Text(), nullable=True)
configuration: Optional[str] = Field(sa_type = Text(), nullable=True)
description: Optional[str] = Field(sa_type=Text(), nullable=True)
configuration: Optional[str] = Field(sa_type=Text(), nullable=True)
create_time: int = Field(default=0, sa_type=BigInteger())
app_id: Optional[str] = Field(default=None, max_length=255, nullable=True)
app_id: Optional[str] = Field(default=None, max_length=255, nullable=True)
app_secret: Optional[str] = Field(default=None, max_length=255, nullable=True)
oid: Optional[int] = Field(nullable=True, sa_type=BigInteger(), default=1)
enable_custom_model: Optional[bool] = Field(default=False, nullable=True)
custom_model: Optional[str] = Field(default=None, max_length=255, nullable=True)


class AssistantModel(SnowflakeBase, AssistantBaseModel, table=True):
__tablename__ = "sys_assistant"


class AuthenticationBaseModel(SQLModel):
name: str = Field(max_length=255, nullable=False)
type: int = Field(nullable=False, default=0)
config: Optional[str] = Field(sa_type = Text(), nullable=True)
config: Optional[str] = Field(sa_type=Text(), nullable=True)


class AuthenticationModel(SnowflakeBase, AuthenticationBaseModel, table=True):
__tablename__ = "sys_authentication"
create_time: Optional[int] = Field(default=0, sa_type=BigInteger())
enable: bool = Field(default=False, nullable=False)
valid: bool = Field(default=False, nullable=False)


class ApiKeyBaseModel(SQLModel):
access_key: str = Field(max_length=255, nullable=False)
secret_key: str = Field(max_length=255, nullable=False)
create_time: int = Field(default=0, sa_type=BigInteger())
uid: int = Field(default=0,nullable=False, sa_type=BigInteger())
uid: int = Field(default=0, nullable=False, sa_type=BigInteger())
status: bool = Field(default=True, nullable=False)



class ApiKeyModel(SnowflakeBase, ApiKeyBaseModel, table=True):
__tablename__ = "sys_apikey"
__tablename__ = "sys_apikey"
1 change: 1 addition & 0 deletions frontend/src/api/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ export const modelApi = {
platform: (id: number, lazy?: number, pid?: string) =>
request.post(`/system/platform/org/${id}`, { lazy, pid }),
userSync: (data: any) => request.post(`/system/platform/user/sync`, data),
list_by_ws: () => request.get(`/system/aimodel/list/by_ws`),
}
44 changes: 22 additions & 22 deletions frontend/src/views/system/embedded/iframe.vue
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,23 @@ const dsListOptions = ref<any[]>([])
const embeddedListWithSearch = computed(() => {
if (!keywords.value) return embeddedList.value
return embeddedList.value.filter((ele: any) =>
ele.name.toLowerCase().includes(keywords.value.toLowerCase()),
ele.name.toLowerCase().includes(keywords.value.toLowerCase())
)
})

interface Model {
id: number
name: string
model_type: string
base_model: string
id: string
default_model: boolean
supplier: number
}

const modelList =ref<Array<Model>>([])
const modelList = ref<Array<Model>>([])

const searchModels = () => {
searchLoading.value = true
modelApi
.queryAll()
.list_by_ws()
.then((res: any) => {
modelList.value = res
})
Expand Down Expand Up @@ -307,8 +305,8 @@ const validateUrl = (_: any, value: any, callback: any) => {
if (value === '') {
callback(
new Error(
t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings'),
),
t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings')
)
)
} else {
// var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line
Expand Down Expand Up @@ -352,7 +350,7 @@ const dsRules = {
const validatePass = (_: any, value: any, callback: any) => {
if (value === '') {
callback(
new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url')),
new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url'))
)
} else {
// var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line
Expand Down Expand Up @@ -470,7 +468,7 @@ const saveEmbedded = () => {
if (!currentEmbedded.id) {
delete obj.id
}
if (obj.custom_model == undefined){
if (obj.custom_model == undefined) {
obj.custom_model = ''
}
req(obj).then(() => {
Expand Down Expand Up @@ -526,29 +524,29 @@ const handleEmbedded = (row: any) => {
}
const copyJsCode = () => {
copy(jsCodeElement.value)
.then(function() {
.then(function () {
ElMessage.success(t('embedded.copy_successful'))
})
.catch(function() {
.catch(function () {
ElMessage.error(t('embedded.copy_failed'))
})
}

const copyJsCodeFull = () => {
copy(jsCodeElementFull.value)
.then(function() {
.then(function () {
ElMessage.success(t('embedded.copy_successful'))
})
.catch(function() {
.catch(function () {
ElMessage.error(t('embedded.copy_failed'))
})
}
const copyCode = () => {
copy(scriptElement.value)
.then(function() {
.then(function () {
ElMessage.success(t('embedded.copy_successful'))
})
.catch(function() {
.catch(function () {
ElMessage.error(t('embedded.copy_failed'))
})
}
Expand Down Expand Up @@ -821,13 +819,16 @@ const saveHandler = () => {
</el-form-item>

<el-form-item prop="enable_custom_model">
<el-checkbox v-model="currentEmbedded.enable_custom_model" >
{{t('embedded.enableCustomModel')}}
<el-checkbox v-model="currentEmbedded.enable_custom_model">
{{ t('embedded.enableCustomModel') }}
</el-checkbox>
</el-form-item>

<el-form-item prop="custom_model" :label="t('modelType.llm')"
v-if="currentEmbedded.enable_custom_model">
<el-form-item
v-if="currentEmbedded.enable_custom_model"
prop="custom_model"
:label="t('modelType.llm')"
>
<el-select v-model="currentEmbedded.custom_model" clearable filterable>
<el-option
v-for="item in modelList"
Expand All @@ -837,7 +838,6 @@ const saveHandler = () => {
/>
</el-select>
</el-form-item>

</el-form>
</div>
</el-scrollbar>
Expand Down Expand Up @@ -1013,7 +1013,7 @@ const saveHandler = () => {
<div class="private-list">
{{ t('embedded.set_data_source') }}
<span :title="$t('embedded.open_the_query')" class="open-the_query ellipsis"
>{{ $t('embedded.open_the_query') }}
>{{ $t('embedded.open_the_query') }}
</span>
</div>
</template>
Expand Down
Loading