From 5fddcda2235849bf4d02159f4d254bf3f9deb0eb Mon Sep 17 00:00:00 2001 From: ulleo Date: Thu, 6 Nov 2025 14:54:13 +0800 Subject: [PATCH] feat: Terminology / SQL Sample Management add enabled control --- backend/alembic/env.py | 2 +- backend/alembic/versions/050_modify_ddl_py.py | 37 +++++++++++++++++++ .../apps/data_training/api/data_training.py | 8 +++- .../apps/data_training/curd/data_training.py | 17 ++++++++- .../models/data_training_model.py | 4 +- backend/apps/terminology/api/terminology.py | 7 +++- backend/apps/terminology/curd/terminology.py | 23 ++++++++++-- .../terminology/models/terminology_model.py | 2 + 8 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 backend/alembic/versions/050_modify_ddl_py.py diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 16ef1c3e9..a01ac8365 100755 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -27,7 +27,7 @@ # from apps.chat.models.chat_model import SQLModel from apps.terminology.models.terminology_model import SQLModel #from apps.custom_prompt.models.custom_prompt_model import SQLModel -# from apps.data_training.models.data_training_model import SQLModel +from apps.data_training.models.data_training_model import SQLModel # from apps.dashboard.models.dashboard_model import SQLModel from common.core.config import settings # noqa #from apps.datasource.models.datasource import SQLModel diff --git a/backend/alembic/versions/050_modify_ddl_py.py b/backend/alembic/versions/050_modify_ddl_py.py new file mode 100644 index 000000000..cb312439c --- /dev/null +++ b/backend/alembic/versions/050_modify_ddl_py.py @@ -0,0 +1,37 @@ +"""050_modify_ddl.py + +Revision ID: 2785e54dc1c4 +Revises: b58a71ca6ae3 +Create Date: 2025-11-06 13:43:50.820328 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '2785e54dc1c4' +down_revision = 'b58a71ca6ae3' +branch_labels = None +depends_on = None + +sql=''' +UPDATE data_training SET enabled = true; +UPDATE terminology SET enabled = true; +''' + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('data_training', sa.Column('enabled', sa.Boolean(), nullable=True)) + op.add_column('terminology', sa.Column('enabled', sa.Boolean(), nullable=True)) + + op.execute(sql) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('terminology', 'enabled') + op.drop_column('data_training', 'enabled') + # ### end Alembic commands ### diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py index 25422c2b2..4b2c1ffe5 100644 --- a/backend/apps/data_training/api/data_training.py +++ b/backend/apps/data_training/api/data_training.py @@ -2,7 +2,8 @@ from fastapi import APIRouter, Query -from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training +from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \ + enable_training from apps.data_training.models.data_training_model import DataTrainingInfo from common.core.deps import SessionDep, CurrentUser, Trans @@ -37,3 +38,8 @@ async def create_or_update(session: SessionDep, current_user: CurrentUser, trans @router.delete("") async def delete(session: SessionDep, id_list: list[int]): delete_training(session, id_list) + + +@router.get("{id}/enable/{enabled}") +async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans): + enable_training(session, id, enabled, trans) diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index b9530437a..26c8d5fd0 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -89,7 +89,7 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans if info.datasource is None: raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid, - datasource=info.datasource) + datasource=info.datasource, enabled=info.enabled) exists = session.query( session.query(DataTraining).filter( @@ -135,6 +135,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans question=info.question, description=info.description, datasource=info.datasource, + enabled=info.enabled, ) session.execute(stmt) session.commit() @@ -151,6 +152,20 @@ def delete_training(session: SessionDep, ids: list[int]): session.commit() +def enable_training(session: SessionDep, id: int, enabled: bool, trans: Trans): + count = session.query(DataTraining).filter( + DataTraining.id == id + ).count() + if count == 0: + raise Exception(trans('i18n_data_training.data_training_not_exists')) + + stmt = update(DataTraining).where(and_(DataTraining.id == id)).values( + enabled=enabled, + ) + session.execute(stmt) + session.commit() + + # def run_save_embeddings(ids: List[int]): # executor.submit(save_embeddings, ids) # diff --git a/backend/apps/data_training/models/data_training_model.py b/backend/apps/data_training/models/data_training_model.py index b28064001..24e69be93 100644 --- a/backend/apps/data_training/models/data_training_model.py +++ b/backend/apps/data_training/models/data_training_model.py @@ -3,7 +3,7 @@ from pgvector.sqlalchemy import VECTOR from pydantic import BaseModel -from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean from sqlmodel import SQLModel, Field @@ -16,6 +16,7 @@ class DataTraining(SQLModel, table=True): question: Optional[str] = Field(max_length=255) description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True)) class DataTrainingInfo(BaseModel): @@ -26,3 +27,4 @@ class DataTrainingInfo(BaseModel): create_time: Optional[datetime] = None question: Optional[str] = None description: Optional[str] = None + enabled: Optional[bool] = True diff --git a/backend/apps/terminology/api/terminology.py b/backend/apps/terminology/api/terminology.py index a7eda0839..58544e2e4 100644 --- a/backend/apps/terminology/api/terminology.py +++ b/backend/apps/terminology/api/terminology.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Query from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \ - delete_terminology + delete_terminology, enable_terminology from apps.terminology.models.terminology_model import TerminologyInfo from common.core.deps import SessionDep, CurrentUser, Trans @@ -37,3 +37,8 @@ async def create_or_update(session: SessionDep, current_user: CurrentUser, trans @router.delete("") async def delete(session: SessionDep, id_list: list[int]): delete_terminology(session, id_list) + + +@router.get("{id}/enable/{enabled}") +async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans): + enable_terminology(session, id, enabled, trans) diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 1d6ea4cae..15b5b1640 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -230,7 +230,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra raise Exception(trans("i18n_terminology.datasource_cannot_be_none")) parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid, - specific_ds=specific_ds, + specific_ds=specific_ds, enabled=info.enabled, datasource_ids=datasource_ids) words = [info.word] @@ -289,7 +289,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if other_word.strip() == "": continue _list.append( - Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, + Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, enabled=result.enabled, specific_ds=specific_ds, datasource_ids=datasource_ids)) session.bulk_save_objects(_list) session.flush() @@ -366,7 +366,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra word=info.word, description=info.description, specific_ds=specific_ds, - datasource_ids=datasource_ids + datasource_ids=datasource_ids, + enabled=info.enabled, ) session.execute(stmt) session.commit() @@ -383,7 +384,7 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra continue _list.append( Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid, - specific_ds=specific_ds, datasource_ids=datasource_ids)) + specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled)) session.bulk_save_objects(_list) session.flush() session.commit() @@ -400,6 +401,20 @@ def delete_terminology(session: SessionDep, ids: list[int]): session.commit() +def enable_terminology(session: SessionDep, id: int, enabled: bool, trans: Trans): + count = session.query(Terminology).filter( + Terminology.id == id + ).count() + if count == 0: + raise Exception(trans('i18n_terminology.terminology_not_exists')) + + stmt = update(Terminology).where(or_(Terminology.id == id, Terminology.pid == id)).values( + enabled=enabled, + ) + session.execute(stmt) + session.commit() + + # def run_save_embeddings(ids: List[int]): # executor.submit(save_embeddings, ids) # diff --git a/backend/apps/terminology/models/terminology_model.py b/backend/apps/terminology/models/terminology_model.py index b90486593..850aadabe 100644 --- a/backend/apps/terminology/models/terminology_model.py +++ b/backend/apps/terminology/models/terminology_model.py @@ -19,6 +19,7 @@ class Terminology(SQLModel, table=True): embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False)) datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[]) + enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True)) class TerminologyInfo(BaseModel): @@ -30,3 +31,4 @@ class TerminologyInfo(BaseModel): specific_ds: Optional[bool] = False datasource_ids: Optional[list[int]] = [] datasource_names: Optional[list[str]] = [] + enabled: Optional[bool] = True