diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index b5c92f40a..e8a3d4de8 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1038,32 +1038,45 @@ def delete_node_by_prams( f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" ) - # First count matching nodes to get accurate count - count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count" - logger.info(f"[delete_node_by_prams] count_query: {count_query}") - print(f"[delete_node_by_prams] count_query: {count_query}") + # First collect IDs of matching nodes (for both Neo4j deletion and Qdrant cleanup) + ids_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN n.id AS node_id" + logger.info(f"[delete_node_by_prams] ids_query: {ids_query}") + print(f"[delete_node_by_prams] ids_query: {ids_query}") - # Then delete nodes + # Then delete nodes from Neo4j delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n" logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") print(f"[delete_node_by_prams] delete_query: {delete_query}") print(f"[delete_node_by_prams] params: {params}") deleted_count = 0 + deleted_ids: list[str] = [] try: with self.driver.session(database=self.db_name) as session: - # Count nodes before deletion - count_result = session.run(count_query, **params) - count_record = count_result.single() - expected_count = 0 - if count_record: - expected_count = count_record["node_count"] or 0 - - # Delete nodes + # Collect IDs before deletion + ids_result = session.run(ids_query, **params) + for rec in ids_result: + node_id = rec.get("node_id") + if node_id: + deleted_ids.append(str(node_id)) + expected_count = len(deleted_ids) + + # Delete nodes from Neo4j session.run(delete_query, **params) - # Use the count from before deletion as the actual deleted count deleted_count = expected_count + # After successful Neo4j deletion, clean up Qdrant vectors + if deleted_ids: + try: + self.vec_db.delete(deleted_ids) + logger.info( + f"[delete_node_by_prams] Deleted {len(deleted_ids)} vectors from Qdrant" + ) + except Exception as vec_err: + logger.warning( + f"[delete_node_by_prams] Failed to delete vectors from Qdrant: {vec_err}" + ) + except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise