8 changes: 4 additions & 4 deletions graphgen/bases/base_storage.py
@@ -79,7 +79,7 @@ def get_node(self, node_id: str) -> Union[dict, None]:
         raise NotImplementedError

     @abstractmethod
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError

     @abstractmethod
@@ -96,7 +96,7 @@ def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:

     @abstractmethod
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         raise NotImplementedError

@@ -113,12 +113,12 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         raise NotImplementedError

     @abstractmethod
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError

     @abstractmethod
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         raise NotImplementedError

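The widened signatures above let node and edge payloads carry non-string values, which is what allows the integer token length added later in this PR to be stored without stringifying it. Below is a minimal sketch of the widened contract; the class name and payload fields are illustrative, and typing.Any is used on the assumption that the annotation is meant as "any value" rather than the builtin any function:

    from abc import ABC, abstractmethod
    from typing import Any


    class BaseGraphStorage(ABC):
        @abstractmethod
        def upsert_node(self, node_id: str, node_data: dict[str, Any]):
            """Insert or update a node; attribute values may be str, int, etc."""
            raise NotImplementedError


    # A payload mixing value types, as the KG builder now produces:
    node_data = {
        "entity_name": "GraphGen",
        "description": "A knowledge-graph construction toolkit.",
        "length": 7,  # token count stored as an int, not a str
    }
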
11 changes: 8 additions & 3 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -18,6 +18,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
     def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
         super().__init__(llm_client)
         self.max_loop = max_loop
+        self.tokenizer = llm_client.tokenizer

     async def extract(
         self, chunk: Chunk
@@ -134,6 +135,7 @@ async def merge_nodes(
             "entity_name": entity_name,
             "description": description,
             "source_id": source_id,
+            "length": self.tokenizer.count_tokens(description),
         }
         kg_instance.upsert_node(entity_name, node_data=node_data)
         return node_data
@@ -167,9 +169,11 @@ async def merge_edges(
             kg_instance.upsert_node(
                 insert_id,
                 node_data={
-                    "source_id": source_id,
-                    "description": description,
                     "entity_type": "UNKNOWN",
+                    "entity_name": insert_id,
+                    "description": "",
+                    "source_id": source_id,
+                    "length": self.tokenizer.count_tokens(description),
                 },
             )

@@ -182,12 +186,13 @@
             "tgt_id": tgt_id,
             "description": description,
             "source_id": source_id,  # for traceability
+            "length": self.tokenizer.count_tokens(description),
         }

         kg_instance.upsert_edge(
             src_id,
             tgt_id,
-            edge_data={"source_id": source_id, "description": description},
+            edge_data=edge_data,
         )
         return edge_data

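With the tokenizer taken from the LLM wrapper, the builder now stamps every node and edge with a token length at extraction time instead of relying on a separate pre-tokenization pass. A rough sketch of the count_tokens contract the builder depends on follows; the tiktoken encoder is only a stand-in assumption, not the wrapper's actual tokenizer:

    import tiktoken


    class Tokenizer:
        """Illustrative stand-in for the tokenizer exposed by BaseLLMWrapper."""

        def __init__(self, encoding_name: str = "cl100k_base"):
            self._enc = tiktoken.get_encoding(encoding_name)

        def count_tokens(self, text: str) -> int:
            return len(self._enc.encode(text))


    tokenizer = Tokenizer()
    description = "GraphGen builds knowledge graphs from text chunks."
    node_data = {
        "entity_name": "GraphGen",
        "description": description,
        "source_id": "chunk-001",  # hypothetical chunk id
        "length": tokenizer.count_tokens(description),
    }
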
2 changes: 1 addition & 1 deletion graphgen/models/partitioner/ece_partitioner.py
@@ -99,7 +99,7 @@ def _add_unit(u):
                 return False
             community_edges[i] = d
             used_e.add(i)
-            token_sum += d.get("length", 0)
+            token_sum += int(d.get("length", 0))
             return True

         _add_unit(seed_unit)
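The explicit int(...) cast makes the token accumulation robust to length values that are missing or were persisted as strings, for example after a round trip through a JSON-string-backed store or from an older graph. A small sketch of the cases it covers:

    token_sum = 0
    # "length" may be an int, a stringified int, or absent entirely.
    for d in ({"length": 12}, {"length": "7"}, {}):
        token_sum += int(d.get("length", 0))
    print(token_sum)  # 19
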
8 changes: 4 additions & 4 deletions graphgen/models/storage/graph/kuzu_storage.py
@@ -215,7 +215,7 @@ def get_node(self, node_id: str) -> Any:
         data_str = result.get_next()[0]
         return self._safe_json_loads(data_str)

-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         current_data = self.get_node(node_id)
         if current_data is None:
             print(f"Node {node_id} not found for update.")
@@ -263,7 +263,7 @@ def get_edge(self, source_node_id: str, target_node_id: str):
         return self._safe_json_loads(data_str)

     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         current_data = self.get_edge(source_node_id, target_node_id)
         if current_data is None:
@@ -318,7 +318,7 @@ def get_node_edges(self, source_node_id: str) -> Any:
             edges.append((src, dst, data))
         return edges

-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         """
         Insert or Update node.
         Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
@@ -336,7 +336,7 @@ def upsert_node(self, node_id: str, node_data: dict[str, str]):
         self._conn.execute(query, {"id": node_id, "data": json_data})

     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         """
         Insert or Update edge.
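Since the Kuzu backend serializes the whole attribute dict to a single JSON string, widening the Python annotation changes nothing in the storage schema, and ints survive the JSON round trip. A rough sketch of the MERGE-based upsert pattern the docstring describes; the table name, column names, and query text here are assumptions, not the module's actual schema:

    import json

    import kuzu

    db = kuzu.Database("./kg_demo")
    conn = kuzu.Connection(db)
    conn.execute(
        "CREATE NODE TABLE IF NOT EXISTS Entity(id STRING, data STRING, PRIMARY KEY(id))"
    )

    node_data = {"entity_name": "GraphGen", "length": 7}  # mixed value types are fine
    conn.execute(
        "MERGE (n:Entity {id: $id}) SET n.data = $data",
        {"id": "GraphGen", "data": json.dumps(node_data)},
    )
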
8 changes: 4 additions & 4 deletions graphgen/models/storage/graph/networkx_storage.py
@@ -144,22 +144,22 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
     def get_graph(self) -> nx.Graph:
         return self._graph

-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         self._graph.add_node(node_id, **node_data)

-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         if self._graph.has_node(node_id):
             self._graph.nodes[node_id].update(node_data)
         else:
             print(f"Node {node_id} not found in the graph for update.")

     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         self._graph.add_edge(source_node_id, target_node_id, **edge_data)

     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         if self._graph.has_edge(source_node_id, target_node_id):
             self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
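For the NetworkX backend, the widened annotation simply matches what the graph already allows, since attribute values can be of any type. A quick usage sketch of the upsert/update pair:

    import networkx as nx

    graph = nx.Graph()

    # upsert: add_node creates the node or merges attributes into an existing one
    graph.add_node("GraphGen", description="KG toolkit", length=7)

    # update: refresh attributes in place on a node known to exist
    graph.nodes["GraphGen"].update({"length": 9})

    print(graph.nodes["GraphGen"]["length"])  # 9
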
41 changes: 2 additions & 39 deletions graphgen/operators/partition/partition_service.py
@@ -60,10 +60,8 @@ def partition(self) -> Iterable[pd.DataFrame]:
             partitioner = DFSPartitioner()
         elif method == "ece":
             logger.info("Partitioning knowledge graph using ECE method.")
-            # TODO: before ECE partitioning, we need to:
-            # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
-            # 2. pre-tokenize nodes and edges to get the token length
-            self._pre_tokenize()
+            # before ECE partitioning, we need to:
+            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
             partitioner = ECEPartitioner()
         elif method == "leiden":
             logger.info("Partitioning knowledge graph using Leiden method.")
@@ -97,41 +95,6 @@ def partition(self) -> Iterable[pd.DataFrame]:
         )
         logger.info("Total communities partitioned: %d", count)

-    def _pre_tokenize(self) -> None:
-        """Pre-tokenize all nodes and edges to add token length information."""
-        logger.info("Starting pre-tokenization of nodes and edges...")
-
-        nodes = self.kg_instance.get_all_nodes()
-        edges = self.kg_instance.get_all_edges()
-
-        # Process nodes
-        for node_id, node_data in nodes:
-            if "length" not in node_data:
-                try:
-                    description = node_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    node_data["length"] = len(tokens)
-                    self.kg_instance.update_node(node_id, node_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize node %s: %s", node_id, e)
-                    node_data["length"] = 0
-
-        # Process edges
-        for u, v, edge_data in edges:
-            if "length" not in edge_data:
-                try:
-                    description = edge_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    edge_data["length"] = len(tokens)
-                    self.kg_instance.update_edge(u, v, edge_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e)
-                    edge_data["length"] = 0
-
-        # Persist changes
-        self.kg_instance.index_done_callback()
-        logger.info("Pre-tokenization completed.")
-
     def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
         """
         Attach additional data from chunk_storage to nodes in the batch.