From 8502a268ebfb4233629f24cc28c427f4c1f8b6c4 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 00:38:25 +0800 Subject: [PATCH 1/9] [Relax] Fix matmul and reductions with zero-size dimension return uninitialized memory --- .../transform/legalize_ops/linear_algebra.py | 2 + .../transform/legalize_ops/statistical.py | 53 +++++++++++++++++-- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index d8dd8aa3b0cc..2fabd7151838 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -45,6 +45,8 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) f_infer_sinfo = call.op.get_attr("FInferStructInfo") output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape + if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0: + return topi.full(output_shape, call.struct_info.dtype, 0) def matmul_compute(*idx_spatial): k = te.reduce_axis((0, a_shape[-1]), name="k") diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index cbad62e44810..168cd7139997 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -17,15 +17,57 @@ # pylint: disable=invalid-name """Default legalization function for statistical operators.""" +from collections.abc import Callable + from tvm import te, tirx, topi from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Call, Expr, ShapeExpr from .common import LegalizeFunc, TEFunc, register_legalize -def _statistical(te_func: TEFunc) -> LegalizeFunc: +def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]: + if axis is None: + return list(range(ndim)) + + axes = [] + for dim in axis: + if isinstance(dim, tirx.IntImm): + dim = dim.value + dim = int(dim) + axes.append(dim + ndim if dim < 0 else dim) + return axes + + +def _has_const_zero_reduction_dim(call: Call) -> bool: + input_shape = call.args[0].struct_info.shape + if not isinstance(input_shape, ShapeExpr): + return False + + axes = _normalize_reduction_axes(call.attrs.axis, len(input_shape.values)) + return any( + isinstance(input_shape.values[dim], tirx.IntImm) and input_shape.values[dim] == 0 + for dim in axes + ) + + +def _statistical( + te_func: TEFunc, + zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None, +) -> LegalizeFunc: def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: + if zero_dim_identity is not None and _has_const_zero_reduction_dim(call): + fill_value = ( + zero_dim_identity(call.struct_info.dtype) + if callable(zero_dim_identity) + else zero_dim_identity + ) + return bb.call_te( + topi.full, + call.struct_info.shape.values, + call.struct_info.dtype, + fill_value, + ) return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) return statistical_call_te @@ -129,5 +171,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.max", _statistical(topi.max)) register_legalize("relax.min", _statistical(topi.min)) -register_legalize("relax.prod", _statistical(topi.prod)) -register_legalize("relax.sum", _statistical(topi.sum)) +register_legalize( + "relax.prod", + _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == "bool" else 1), +) +register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0)) From a732bdd512d20768dca359d373abd6bc3d7ed298 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 14:24:26 +0800 Subject: [PATCH 2/9] [Relax] Add test case: test_matmul_zero_k_no_reduction --- ...ransform_legalize_ops_index_linear_algebra.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index dbd92ba6d378..9a8b61997233 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -1136,6 +1136,22 @@ def main(x: R.Tensor((1, 1, 4, 5), dtype="float32"), y: R.Tensor((1, 1, 5, 7), d tvm.ir.assert_structural_equal(mod, Expected) +def test_matmul_zero_k_no_reduction(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.matmul(x, y) + return gv + # fmt: on + + mod = LegalizeOps()(Matmul) + script = mod.script() + assert "T.axis.reduce" not in script + assert "T.float32(0)" in script + + def test_einsum(): # fmt: off @I.ir_module(s_tir=True) From 853dafc2c33512a6256a91bb702689ab6d82043c Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 15:40:21 +0800 Subject: [PATCH 3/9] [Relax] Change topi.full(output_shape, ...) to te.compute(output_shape, ...) to avoid ShapeExpr -> Array type error --- python/tvm/relax/transform/legalize_ops/linear_algebra.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index 2fabd7151838..2b4d1efd108f 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -46,7 +46,11 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: f_infer_sinfo = call.op.get_attr("FInferStructInfo") output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0: - return topi.full(output_shape, call.struct_info.dtype, 0) + return te.compute( + output_shape, + lambda *_: tirx.const(0, call.struct_info.dtype), + name="matmul", + ) def matmul_compute(*idx_spatial): k = te.reduce_axis((0, a_shape[-1]), name="k") From 20262d3ba9294bb9e6aa5350be3f609069b37532 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 17:53:59 +0800 Subject: [PATCH 4/9] [Relax] update the test case assertion --- .../relax/test_transform_legalize_ops_index_linear_algebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 9a8b61997233..9b905dd3da30 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -1149,7 +1149,7 @@ def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), "float32")) -> R.Te mod = LegalizeOps()(Matmul) script = mod.script() assert "T.axis.reduce" not in script - assert "T.float32(0)" in script + assert "T.float32(0)" in script or "T.float32(0.0)" in script def test_einsum(): From e1206854919eef46336c7600094994e2a011b568 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 19:47:54 +0800 Subject: [PATCH 5/9] [Relax] Add test case: test_sum_zero_dim_axis_identity --- ..._transform_legalize_ops_search_statistical.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 1a0b71690d37..768a7c9cfdbd 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -629,6 +629,22 @@ def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T tvm.ir.assert_structural_equal(mod, Expected) +def test_sum_zero_dim_axis_identity(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"): + gv: R.Tensor((2, 4), "float32") = R.sum(x, axis=[1], keepdims=False) + return gv + # fmt: on + + mod = LegalizeOps()(Sum) + script = mod.script() + assert "T.axis.reduce" not in script + assert "T.float32(0)" in script or "T.float32(0.0)" in script + + def test_mean(): # fmt: off @tvm.script.ir_module From c06dd7ada85cac44939b2bc42b26ce49fcd1595d Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 21:19:35 +0800 Subject: [PATCH 6/9] [Relax] Add test case: test_sum_zero_dim_negative_axis_identity --- ..._transform_legalize_ops_search_statistical.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 768a7c9cfdbd..e4a67a7af6ad 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -645,6 +645,22 @@ def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"): assert "T.float32(0)" in script or "T.float32(0.0)" in script +def test_sum_zero_dim_negative_axis_identity(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sum(x, axis=[-1], keepdims=False) + return gv + # fmt: on + + mod = LegalizeOps()(Sum) + script = mod.script() + assert "T.axis.reduce" not in script + assert "T.float32(0)" in script or "T.float32(0.0)" in script + + def test_mean(): # fmt: off @tvm.script.ir_module From 58bfa3ef379362efaa208b451cbf57f69d233a0f Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 23:05:19 +0800 Subject: [PATCH 7/9] [Relax] Add test case: test_prod_zero_dim_axis_identity --- ..._transform_legalize_ops_search_statistical.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index e4a67a7af6ad..02ad1d818037 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -661,6 +661,22 @@ def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), "float32"): assert "T.float32(0)" in script or "T.float32(0.0)" in script +def test_prod_zero_dim_axis_identity(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"): + gv: R.Tensor((2, 4), "float32") = R.prod(x, axis=[1], keepdims=False) + return gv + # fmt: on + + mod = LegalizeOps()(Prod) + script = mod.script() + assert "T.axis.reduce" not in script + assert "T.float32(1)" in script or "T.float32(1.0)" in script + + def test_mean(): # fmt: off @tvm.script.ir_module From 7b774f4b76b475c480b018a8b5760ae4aae440b5 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 7 Jun 2026 23:14:03 +0800 Subject: [PATCH 8/9] [Relax] Add test case: test_prod_bool_zero_dim_axis_identity --- ..._transform_legalize_ops_search_statistical.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 02ad1d818037..d8b95a141968 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -677,6 +677,22 @@ def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"): assert "T.float32(1)" in script or "T.float32(1.0)" in script +def test_prod_bool_zero_dim_axis_identity(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"): + gv: R.Tensor((2, 4), "bool") = R.prod(x, axis=[1], keepdims=False) + return gv + # fmt: on + + mod = LegalizeOps()(Prod) + script = mod.script() + assert "T.axis.reduce" not in script + assert "T.bool(1)" in script + + def test_mean(): # fmt: off @tvm.script.ir_module From bf00afe769d5f3cf37980f685f29e88f5ace2764 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Mon, 8 Jun 2026 09:59:48 +0800 Subject: [PATCH 9/9] Update assertion for T.bool in legalize ops test --- .../relax/test_transform_legalize_ops_search_statistical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index d8b95a141968..3244d7134e72 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -690,7 +690,7 @@ def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"): mod = LegalizeOps()(Prod) script = mod.script() assert "T.axis.reduce" not in script - assert "T.bool(1)" in script + assert "T.bool(1)" in script or "T.bool(True)" in script def test_mean():