Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion core.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ def get_node_shape(self, node_or_name):
if "shape" in node.attr:
return [dim.size for dim in node.attr["shape"].shape.dim]

if node.op == "Const" and "value" in node.attr:
tensor = node.attr["value"].tensor
if tensor.HasField("tensor_shape"):
return [d.size for d in tensor.tensor_shape.dim]

return None

def get_node_rank(self, node_or_name):
Expand Down Expand Up @@ -1136,9 +1141,11 @@ def CommutativeOp(

def ConstValue(value, alias=None):
"""Matches a Const node with a specific value."""
import numpy as np

def check_value(unwrapped_value):
return unwrapped_value == value
# Use np.all() for element-wise comparison on arrays
return np.all(np.equal(unwrapped_value, value))

return Op("Const", attrs={"value": check_value}, alias=alias)

Expand Down
Loading