Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions transforms/scalar/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@
"""

from collections import defaultdict

from ...core import BasePass, PassRegistry
from ...utils.hash_utils import hash_tensor_value
from ...utils.logger import logger as logging
from ...core import PassRegistry, BasePass


def extract_key_attrs(attrs, op_type=None):
Expand Down Expand Up @@ -165,9 +167,9 @@ def extract_key_attrs(attrs, op_type=None):
key_attrs.append((attr_name, "shape", shape_dims))
elif attr_value.HasField("tensor"):
# tensor 类型(Const 节点的 value 属性)
# 序列化为字节串确保相同值的常量有相同签名
tensor_bytes = attr_value.tensor.SerializeToString()
key_attrs.append((attr_name, "tensor", tensor_bytes))
# 使用哈希值代替完整的序列化字节串,以优化性能
tensor_hash = hash_tensor_value(attr_value.tensor)
key_attrs.append((attr_name, "tensor", tensor_hash))
elif attr_value.HasField("func"):
# 函数引用(如 While 循环的 body/cond)
key_attrs.append((attr_name, "func", attr_value.func.name))
Expand Down Expand Up @@ -196,7 +198,7 @@ def extract_key_attrs(attrs, op_type=None):
key_attrs.append((attr_name, "list_shape", shapes))
elif attr_value.list.tensor:
# list of tensor
tensors = tuple(t.SerializeToString() for t in attr_value.list.tensor)
tensors = tuple(hash_tensor_value(t) for t in attr_value.list.tensor)
key_attrs.append((attr_name, "list_tensor", tensors))
elif attr_value.list.func:
# list of func
Expand Down
16 changes: 16 additions & 0 deletions utils/hash_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import hashlib

def hash_tensor_value(tensor_proto):
    """Return the MD5 hex digest of a tensor proto's serialized bytes.

    The digest is a compact stand-in for the full serialized message when
    tensors are compared or used as part of a lookup key. It covers
    everything ``SerializeToString()`` emits, not only the element data.

    Args:
        tensor_proto: The TensorProto object; anything exposing a
            ``SerializeToString()`` method that returns bytes works.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    serialized = tensor_proto.SerializeToString()
    try:
        # MD5 is used for speed, not security. Declaring that explicitly
        # keeps the call working on FIPS-restricted builds, which otherwise
        # refuse to construct an MD5 hasher.
        hasher = hashlib.md5(serialized, usedforsecurity=False)
    except TypeError:
        # Python < 3.9 does not accept the ``usedforsecurity`` keyword.
        hasher = hashlib.md5(serialized)
    return hasher.hexdigest()