Skip to content
Open
6 changes: 6 additions & 0 deletions python/tvm/relax/transform/legalize_ops/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
b_relax = relax.Var("b", relax.TensorStructInfo(b.shape))
f_infer_sinfo = call.op.get_attr("FInferStructInfo")
output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape
if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0:
return te.compute(
output_shape,
lambda *_: tirx.const(0, call.struct_info.dtype),
name="matmul",
)

def matmul_compute(*idx_spatial):
k = te.reduce_axis((0, a_shape[-1]), name="k")
Expand Down
53 changes: 49 additions & 4 deletions python/tvm/relax/transform/legalize_ops/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,57 @@
# pylint: disable=invalid-name
"""Default legalization function for statistical operators."""

from collections.abc import Callable

from tvm import te, tirx, topi

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from ...expr import Call, Expr, ShapeExpr
from .common import LegalizeFunc, TEFunc, register_legalize


def _statistical(te_func: TEFunc) -> LegalizeFunc:
def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]:
if axis is None:
return list(range(ndim))

axes = []
for dim in axis:
if isinstance(dim, tirx.IntImm):
dim = dim.value
dim = int(dim)
axes.append(dim + ndim if dim < 0 else dim)
return axes


def _has_const_zero_reduction_dim(call: Call) -> bool:
    """Tell whether any reduced dimension of the first argument is the constant 0.

    Parameters
    ----------
    call : Call
        The call to inspect. The shape is read from ``call.args[0].struct_info``
        and the reduction axes from ``call.attrs.axis``.

    Returns
    -------
    result : bool
        ``True`` when at least one reduced dimension is a ``tirx.IntImm`` equal
        to 0. ``False`` otherwise, including when the input shape is not a
        ``ShapeExpr``.
    """
    shape = call.args[0].struct_info.shape
    if isinstance(shape, ShapeExpr):
        extents = shape.values
        for reduced_dim in _normalize_reduction_axes(call.attrs.axis, len(extents)):
            extent = extents[reduced_dim]
            # Only constant extents are considered; anything else is skipped.
            if isinstance(extent, tirx.IntImm) and extent == 0:
                return True
    return False


def _statistical(
    te_func: TEFunc,
    zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None,
) -> LegalizeFunc:
    """Build a legalization function for a reduction operator.

    Parameters
    ----------
    te_func : TEFunc
        The TE function invoked with the input, ``attrs.axis`` and
        ``attrs.keepdims`` on the regular path.

    zero_dim_identity : int | float | bool | Callable[[str], int | float | bool] | None
        Value used to fill the output when a reduced dimension is the constant
        0. A callable is invoked with the output dtype to obtain the value.
        ``None`` (the default) disables this path.

    Returns
    -------
    legalize : LegalizeFunc
        A function taking ``(bb, call)`` and returning the legalized expression.
    """

    def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr:
        fill_instead = zero_dim_identity is not None and _has_const_zero_reduction_dim(call)
        if not fill_instead:
            return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims)

        # A reduced extent is the constant 0: emit a constant fill of the output shape.
        out_sinfo = call.struct_info
        if callable(zero_dim_identity):
            fill_value = zero_dim_identity(out_sinfo.dtype)
        else:
            fill_value = zero_dim_identity
        return bb.call_te(topi.full, out_sinfo.shape.values, out_sinfo.dtype, fill_value)

    return statistical_call_te
Expand Down Expand Up @@ -129,5 +171,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr:

# max/min pass no zero_dim_identity, so their legalization always calls the TOPI reduction.
register_legalize("relax.max", _statistical(topi.max))
register_legalize("relax.min", _statistical(topi.min))
# NOTE(review): relax.prod and relax.sum are each registered twice in this span (first
# without, then with zero_dim_identity) -- presumably the before/after lines of a diff;
# confirm which registrations are meant to remain.
register_legalize("relax.prod", _statistical(topi.prod))
register_legalize("relax.sum", _statistical(topi.sum))
# Fill value for prod when a reduced dimension is the constant 0: True for "bool", else 1.
register_legalize(
    "relax.prod",
    _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == "bool" else 1),
)
# Fill value for sum when a reduced dimension is the constant 0: 0.
register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0))
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,22 @@ def main(x: R.Tensor((1, 1, 4, 5), dtype="float32"), y: R.Tensor((1, 1, 5, 7), d
tvm.ir.assert_structural_equal(mod, Expected)


def test_matmul_zero_k_no_reduction():
    """Legalized (2, 0) x (0, 3) matmul prints no reduce axis and contains a float32 zero."""
    # fmt: off
    @tvm.script.ir_module
    class Matmul:
        @R.function
        def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.matmul(x, y)
            return gv
    # fmt: on

    legalized_script = LegalizeOps()(Matmul).script()
    # Either spelling of the zero constant is accepted.
    zero_literals = ("T.float32(0)", "T.float32(0.0)")
    assert "T.axis.reduce" not in legalized_script
    assert any(literal in legalized_script for literal in zero_literals)


def test_einsum():
# fmt: off
@I.ir_module(s_tir=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,70 @@ def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T
tvm.ir.assert_structural_equal(mod, Expected)


def test_sum_zero_dim_axis_identity():
    """Legalized sum over axis 1 of a (2, 0, 4) tensor prints no reduce axis and a float32 zero."""
    # fmt: off
    @tvm.script.ir_module
    class Sum:
        @R.function
        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
            gv: R.Tensor((2, 4), "float32") = R.sum(x, axis=[1], keepdims=False)
            return gv
    # fmt: on

    legalized_script = LegalizeOps()(Sum).script()
    # Either spelling of the zero constant is accepted.
    zero_literals = ("T.float32(0)", "T.float32(0.0)")
    assert "T.axis.reduce" not in legalized_script
    assert any(literal in legalized_script for literal in zero_literals)


def test_sum_zero_dim_negative_axis_identity():
    """Legalized sum over axis -1 of a (2, 3, 0) tensor prints no reduce axis and a float32 zero."""
    # fmt: off
    @tvm.script.ir_module
    class Sum:
        @R.function
        def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.sum(x, axis=[-1], keepdims=False)
            return gv
    # fmt: on

    legalized_script = LegalizeOps()(Sum).script()
    # Either spelling of the zero constant is accepted.
    zero_literals = ("T.float32(0)", "T.float32(0.0)")
    assert "T.axis.reduce" not in legalized_script
    assert any(literal in legalized_script for literal in zero_literals)


def test_prod_zero_dim_axis_identity():
    """Legalized prod over axis 1 of a (2, 0, 4) tensor prints no reduce axis and a float32 one."""
    # fmt: off
    @tvm.script.ir_module
    class Prod:
        @R.function
        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
            gv: R.Tensor((2, 4), "float32") = R.prod(x, axis=[1], keepdims=False)
            return gv
    # fmt: on

    legalized_script = LegalizeOps()(Prod).script()
    # Either spelling of the one constant is accepted.
    one_literals = ("T.float32(1)", "T.float32(1.0)")
    assert "T.axis.reduce" not in legalized_script
    assert any(literal in legalized_script for literal in one_literals)


def test_prod_bool_zero_dim_axis_identity():
    """Legalized prod over axis 1 of a (2, 0, 4) bool tensor prints no reduce axis and a true bool."""
    # fmt: off
    @tvm.script.ir_module
    class Prod:
        @R.function
        def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"):
            gv: R.Tensor((2, 4), "bool") = R.prod(x, axis=[1], keepdims=False)
            return gv
    # fmt: on

    legalized_script = LegalizeOps()(Prod).script()
    # Either spelling of the true constant is accepted.
    true_literals = ("T.bool(1)", "T.bool(True)")
    assert "T.axis.reduce" not in legalized_script
    assert any(literal in legalized_script for literal in true_literals)


def test_mean():
# fmt: off
@tvm.script.ir_module
Expand Down
Loading