diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py
index e72c5869..dd9cf151 100644
--- a/graphgen/bases/base_storage.py
+++ b/graphgen/bases/base_storage.py
@@ -79,7 +79,7 @@ def get_node(self, node_id: str) -> Union[dict, None]:
         raise NotImplementedError
 
     @abstractmethod
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError
 
     @abstractmethod
@@ -96,7 +96,7 @@ def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
 
     @abstractmethod
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
    ):
         raise NotImplementedError
 
@@ -113,12 +113,12 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         raise NotImplementedError
 
     @abstractmethod
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError
 
     @abstractmethod
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         raise NotImplementedError
 
diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py
index c3e7345c..53fe1d66 100644
--- a/graphgen/models/kg_builder/light_rag_kg_builder.py
+++ b/graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -18,6 +18,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
     def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
         super().__init__(llm_client)
         self.max_loop = max_loop
+        self.tokenizer = llm_client.tokenizer
 
     async def extract(
         self, chunk: Chunk
@@ -134,6 +135,7 @@ async def merge_nodes(
             "entity_name": entity_name,
             "description": description,
             "source_id": source_id,
+            "length": self.tokenizer.count_tokens(description),
         }
         kg_instance.upsert_node(entity_name, node_data=node_data)
         return node_data
@@ -167,9 +169,11 @@ async def merge_edges(
                 kg_instance.upsert_node(
                     insert_id,
                     node_data={
-                        "source_id": source_id,
-                        "description": description,
                         "entity_type": "UNKNOWN",
+                        "entity_name": insert_id,
+                        "description": "",
+                        "source_id": source_id,
+                        "length": self.tokenizer.count_tokens(description),
                     },
                 )
 
@@ -182,12 +186,13 @@ async def merge_edges(
             "tgt_id": tgt_id,
             "description": description,
             "source_id": source_id,  # for traceability
+            "length": self.tokenizer.count_tokens(description),
         }
 
         kg_instance.upsert_edge(
             src_id,
             tgt_id,
-            edge_data={"source_id": source_id, "description": description},
+            edge_data=edge_data,
         )
         return edge_data
 
diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py
index 46765fd7..733f6ea1 100644
--- a/graphgen/models/partitioner/ece_partitioner.py
+++ b/graphgen/models/partitioner/ece_partitioner.py
@@ -99,7 +99,7 @@ def _add_unit(u):
                 return False
             community_edges[i] = d
             used_e.add(i)
-            token_sum += d.get("length", 0)
+            token_sum += int(d.get("length", 0))
             return True
 
         _add_unit(seed_unit)
diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/models/storage/graph/kuzu_storage.py
index 52b41519..f2acf40c 100644
--- a/graphgen/models/storage/graph/kuzu_storage.py
+++ b/graphgen/models/storage/graph/kuzu_storage.py
@@ -215,7 +215,7 @@ def get_node(self, node_id: str) -> Any:
         data_str = result.get_next()[0]
         return self._safe_json_loads(data_str)
 
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         current_data = self.get_node(node_id)
         if current_data is None:
             print(f"Node {node_id} not found for update.")
@@ -263,7 +263,7 @@ def get_edge(self, source_node_id: str, target_node_id: str):
         return self._safe_json_loads(data_str)
 
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         current_data = self.get_edge(source_node_id, target_node_id)
         if current_data is None:
@@ -318,7 +318,7 @@ def get_node_edges(self, source_node_id: str) -> Any:
             edges.append((src, dst, data))
         return edges
 
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         """
         Insert or Update node.
         Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
@@ -336,7 +336,7 @@ def upsert_node(self, node_id: str, node_data: dict[str, str]):
         self._conn.execute(query, {"id": node_id, "data": json_data})
 
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         """
         Insert or Update edge.
diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py
index b043e9d2..4635eb61 100644
--- a/graphgen/models/storage/graph/networkx_storage.py
+++ b/graphgen/models/storage/graph/networkx_storage.py
@@ -144,22 +144,22 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
     def get_graph(self) -> nx.Graph:
         return self._graph
 
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         self._graph.add_node(node_id, **node_data)
 
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         if self._graph.has_node(node_id):
             self._graph.nodes[node_id].update(node_data)
         else:
             print(f"Node {node_id} not found in the graph for update.")
 
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         if self._graph.has_edge(source_node_id, target_node_id):
             self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py
index 8985e5b8..ff215fce 100644
--- a/graphgen/operators/partition/partition_service.py
+++ b/graphgen/operators/partition/partition_service.py
@@ -60,10 +60,8 @@ def partition(self) -> Iterable[pd.DataFrame]:
             partitioner = DFSPartitioner()
         elif method == "ece":
             logger.info("Partitioning knowledge graph using ECE method.")
-            # TODO: before ECE partitioning, we need to:
-            # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
-            # 2. pre-tokenize nodes and edges to get the token length
-            self._pre_tokenize()
+            # before ECE partitioning, we need to:
+            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
             partitioner = ECEPartitioner()
         elif method == "leiden":
             logger.info("Partitioning knowledge graph using Leiden method.")
@@ -97,41 +95,6 @@ def partition(self) -> Iterable[pd.DataFrame]:
         )
         logger.info("Total communities partitioned: %d", count)
 
-    def _pre_tokenize(self) -> None:
-        """Pre-tokenize all nodes and edges to add token length information."""
-        logger.info("Starting pre-tokenization of nodes and edges...")
-
-        nodes = self.kg_instance.get_all_nodes()
-        edges = self.kg_instance.get_all_edges()
-
-        # Process nodes
-        for node_id, node_data in nodes:
-            if "length" not in node_data:
-                try:
-                    description = node_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    node_data["length"] = len(tokens)
-                    self.kg_instance.update_node(node_id, node_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize node %s: %s", node_id, e)
-                    node_data["length"] = 0
-
-        # Process edges
-        for u, v, edge_data in edges:
-            if "length" not in edge_data:
-                try:
-                    description = edge_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    edge_data["length"] = len(tokens)
-                    self.kg_instance.update_edge(u, v, edge_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e)
-                    edge_data["length"] = 0
-
-        # Persist changes
-        self.kg_instance.index_done_callback()
-        logger.info("Pre-tokenization completed.")
-
     def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
         """
         Attach additional data from chunk_storage to nodes in the batch.
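
Note: the patch moves token counting from partition time (the removed _pre_tokenize pass) to KG-build time, storing a "length" field on every node and edge so ECEPartitioner can enforce its token budget without re-encoding descriptions. Below is a minimal sketch of that data flow; the tokenizer class is a hypothetical stand-in for llm_client.tokenizer, of which the patch only uses the count_tokens() method.

class WhitespaceTokenizer:
    # Hypothetical stand-in for llm_client.tokenizer; only the
    # count_tokens() interface used by the patch matters here.
    def count_tokens(self, text: str) -> int:
        return len(text.split())


tokenizer = WhitespaceTokenizer()

# Build time: the builder stores the token count once, alongside the description.
description = "entity_a supplies raw material to entity_b"
edge_data = {
    "src_id": "entity_a",
    "tgt_id": "entity_b",
    "description": description,
    "source_id": "chunk-1",
    "length": tokenizer.count_tokens(description),
}

# Partition time: ECE reads the stored value instead of re-tokenizing.
# The int(...) cast guards against the value round-tripping through JSON
# as a string, since the Kuzu backend serializes properties to JSON.
token_sum = 0
token_sum += int(edge_data.get("length", 0))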