diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index d8dd8aa3b0cc..2b4d1efd108f 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -45,6 +45,13 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
         b_relax = relax.Var("b", relax.TensorStructInfo(b.shape))
         f_infer_sinfo = call.op.get_attr("FInferStructInfo")
         output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape
+        # Reduction extent is a constant 0: emit an all-zero result without a reduce axis.
+        if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0:
+            return te.compute(
+                output_shape,
+                lambda *_: tirx.const(0, call.struct_info.dtype),
+                name="matmul",
+            )
 
         def matmul_compute(*idx_spatial):
             k = te.reduce_axis((0, a_shape[-1]), name="k")
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index cbad62e44810..168cd7139997 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -17,15 +17,73 @@
 # pylint: disable=invalid-name
 """Default legalization function for statistical operators."""
 
+from collections.abc import Callable
+
 from tvm import te, tirx, topi
 
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, ShapeExpr
 from .common import LegalizeFunc, TEFunc, register_legalize
 
 
-def _statistical(te_func: TEFunc) -> LegalizeFunc:
+def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]:
+    """Return the reduction axes as non-negative Python ints.
+
+    ``None`` selects every axis in ``range(ndim)``; ``tirx.IntImm`` entries are
+    unwrapped through ``.value`` and negative entries are offset by ``ndim``.
+    """
+    if axis is None:
+        return list(range(ndim))
+
+    axes = []
+    for dim in axis:
+        if isinstance(dim, tirx.IntImm):
+            dim = dim.value
+        dim = int(dim)
+        axes.append(dim + ndim if dim < 0 else dim)
+    return axes
+
+
+def _has_const_zero_reduction_dim(call: Call) -> bool:
+    """Tell whether a reduced axis of the first argument has constant extent 0.
+
+    Returns ``False`` when the argument's shape is not a ``ShapeExpr``; only
+    ``tirx.IntImm`` extents are inspected.
+    """
+    input_shape = call.args[0].struct_info.shape
+    if not isinstance(input_shape, ShapeExpr):
+        return False
+
+    axes = _normalize_reduction_axes(call.attrs.axis, len(input_shape.values))
+    return any(
+        isinstance(input_shape.values[dim], tirx.IntImm) and input_shape.values[dim] == 0
+        for dim in axes
+    )
+
+
+def _statistical(
+    te_func: TEFunc,
+    zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None,
+) -> LegalizeFunc:
+    """Build a legalizer that lowers a reduction through ``te_func``.
+
+    When ``zero_dim_identity`` is given (a constant, or a callable that maps the
+    output dtype to one) and a reduced axis has constant extent 0, the result is
+    emitted as ``topi.full`` of that value instead of calling ``te_func``.
+    """
     def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        if zero_dim_identity is not None and _has_const_zero_reduction_dim(call):
+            fill_value = (
+                zero_dim_identity(call.struct_info.dtype)
+                if callable(zero_dim_identity)
+                else zero_dim_identity
+            )
+            return bb.call_te(
+                topi.full,
+                call.struct_info.shape.values,
+                call.struct_info.dtype,
+                fill_value,
+            )
         return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims)
 
     return statistical_call_te
@@ -129,5 +187,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr:
 
 register_legalize("relax.max", _statistical(topi.max))
 register_legalize("relax.min", _statistical(topi.min))
-register_legalize("relax.prod", _statistical(topi.prod))
-register_legalize("relax.sum", _statistical(topi.sum))
+register_legalize(
+    "relax.prod",
+    _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == "bool" else 1),
+)
+register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0))
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index dbd92ba6d378..9b905dd3da30 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -1136,6 +1136,22 @@ def main(x: R.Tensor((1, 1, 4, 5), dtype="float32"), y: R.Tensor((1, 1, 5, 7), d
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_matmul_zero_k_no_reduction():
+    # fmt: off
+    @tvm.script.ir_module
+    class Matmul:
+        @R.function
+        def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.matmul(x, y)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Matmul)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
 def test_einsum():
     # fmt: off
     @I.ir_module(s_tir=True)
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 1a0b71690d37..3244d7134e72 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -629,6 +629,70 @@ def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_sum_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.sum(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_sum_zero_dim_negative_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.sum(x, axis=[-1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_prod_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.prod(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(1)" in script or "T.float32(1.0)" in script
+
+
+def test_prod_bool_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"):
+            gv: R.Tensor((2, 4), "bool") = R.prod(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.bool(1)" in script or "T.bool(True)" in script
+
+
 def test_mean():
     # fmt: off
     @tvm.script.ir_module