diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index 26c8d5fd0..7bc51b15e 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -62,6 +62,7 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in DataTraining.question, DataTraining.create_time, DataTraining.description, + DataTraining.enabled, ) .outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id)) .where(and_(DataTraining.id.in_(paginated_parent_ids))) @@ -79,6 +80,7 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in question=row.question, create_time=row.create_time, description=row.description, + enabled=row.enabled, )) return current_page, page_size, total_count, total_pages, _list @@ -221,11 +223,11 @@ def save_embeddings(session_maker, ids: List[int]): embedding_sql = f""" SELECT id, datasource, question, similarity FROM -(SELECT id, datasource, question, oid, +(SELECT id, datasource, question, oid, enabled, ( 1 - (embedding <=> :embedding_array) ) AS similarity FROM data_training AS child ) TEMP -WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and datasource = :datasource +WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and datasource = :datasource and enabled = true ORDER BY similarity DESC LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT} """ @@ -246,7 +248,8 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da .where( and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")), DataTraining.oid == oid, - DataTraining.datasource == datasource) + DataTraining.datasource == datasource, + DataTraining.enabled == True,) ) ) diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 15b5b1640..8720bff70 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -99,7 +99,8 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.specific_ds, Terminology.datasource_ids, children_subquery.c.other_words, - func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'), + Terminology.enabled ) .outerjoin( children_subquery, @@ -122,7 +123,8 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.description, Terminology.specific_ds, Terminology.datasource_ids, - children_subquery.c.other_words + children_subquery.c.other_words, + Terminology.enabled ) .order_by(Terminology.create_time.desc()) ) @@ -175,7 +177,8 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.specific_ds, Terminology.datasource_ids, children_subquery.c.other_words, - func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'), + Terminology.enabled ) .outerjoin( children_subquery, @@ -197,7 +200,8 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.description, Terminology.specific_ds, Terminology.datasource_ids, - children_subquery.c.other_words + children_subquery.c.other_words, + Terminology.enabled ) .order_by(Terminology.create_time.desc()) ) @@ -214,6 +218,7 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int specific_ds=row.specific_ds if row.specific_ds is not None else False, datasource_ids=row.datasource_ids if row.datasource_ids is not None else [], datasource_names=row.datasource_names if row.datasource_names is not None else [], + enabled=row.enabled if row.enabled is not None else False, )) return current_page, page_size, total_count, total_pages, _list @@ -474,11 +479,11 @@ def save_embeddings(session_maker, ids: List[int]): embedding_sql = f""" SELECT id, pid, word, similarity FROM -(SELECT id, pid, word, oid, specific_ds, datasource_ids, +(SELECT id, pid, word, oid, specific_ds, datasource_ids, enabled, ( 1 - (embedding <=> :embedding_array) ) AS similarity FROM terminology AS child ) TEMP -WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid +WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid AND enabled = true AND (specific_ds = false OR specific_ds IS NULL) ORDER BY similarity DESC LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT} @@ -487,11 +492,11 @@ def save_embeddings(session_maker, ids: List[int]): embedding_sql_with_datasource = f""" SELECT id, pid, word, similarity FROM -(SELECT id, pid, word, oid, specific_ds, datasource_ids, +(SELECT id, pid, word, oid, specific_ds, datasource_ids, enabled, ( 1 - (embedding <=> :embedding_array) ) AS similarity FROM terminology AS child ) TEMP -WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid +WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid AND enabled = true AND ( (specific_ds = false OR specific_ds IS NULL) OR @@ -515,7 +520,7 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasou Terminology.word, ) .where( - and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid) + and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid, Terminology.enabled == True) ) )