diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index e3f5444efe38..e9a1976e28cf 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -184,7 +184,9 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, FReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); - TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + if (ndim == 0) { + return topi::identity(data, data->op->name + "_red", kCommReduce); + } auto real_axis = GetRealAxis(static_cast(ndim), axis); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); return DoCommReduce(data, func, target_shape, real_axis, diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 1a0b71690d37..82c478bd5168 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -1179,5 +1179,37 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype=" tvm.ir.assert_structural_equal(mod, Expected) +def test_max_zero_dim(): + # Reducing a 0-D (scalar) tensor is the identity; it must legalize, not crash. + # Regression test for https://github.com/apache/tvm/issues/19676 + # fmt: off + @tvm.script.ir_module + class Max: + @R.function + def main(x: R.Tensor((), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.max(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((), dtype="float32"): + gv = R.call_tir(Expected.max, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func(private=True, s_tir=True) + def max(x: T.Buffer((), "float32"), x_red: T.Buffer((), "float32")): + T.func_attr({"tirx.noalias": True}) + with T.sblock("x_red"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(x[()]) + T.writes(x_red[()]) + x_red[()] = x[()] + # fmt: on + + mod = LegalizeOps()(Max) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main()