Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func,
inline Tensor CommReduce(const Tensor& data, const ffi::Optional<ffi::Array<int64_t>>& axis,
FReduce func, bool keepdims, bool atleast1d) {
auto ndim = data->shape.size();
TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
if (ndim == 0) {
return topi::identity(data, data->op->name + "_red", kCommReduce);
}
Comment on lines +187 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When ndim == 0 and atleast1d is true, the output tensor should be at least 1-dimensional (i.e., shape [1]). However, returning topi::identity directly ignores the atleast1d flag and returns a 0-dimensional tensor. We should wrap the identity tensor with topi::expand_dims if atleast1d is enabled.

Suggested change
if (ndim == 0) {
return topi::identity(data, data->op->name + "_red", kCommReduce);
}
if (ndim == 0) {
auto identity = topi::identity(data, data->op->name + "_red", kCommReduce);
return atleast1d ? topi::expand_dims(identity, 0, 1) : identity;
}

auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
return DoCommReduce(data, func, target_shape, real_axis,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1179,5 +1179,37 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="
tvm.ir.assert_structural_equal(mod, Expected)


def test_max_zero_dim():
# Reducing a 0-D (scalar) tensor is the identity; it must legalize, not crash.
# Regression test for https://github.com/apache/tvm/issues/19676
# fmt: off
@tvm.script.ir_module
class Max:
@R.function
def main(x: R.Tensor((), "float32")) -> R.Tensor((), "float32"):
gv: R.Tensor((), "float32") = R.max(x)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((), dtype="float32"):
gv = R.call_tir(Expected.max, (x,), R.Tensor((), dtype="float32"))
return gv

@T.prim_func(private=True, s_tir=True)
def max(x: T.Buffer((), "float32"), x_red: T.Buffer((), "float32")):
T.func_attr({"tirx.noalias": True})
with T.sblock("x_red"):
vi = T.axis.spatial(1, T.int64(0))
T.reads(x[()])
T.writes(x_red[()])
x_red[()] = x[()]
# fmt: on

mod = LegalizeOps()(Max)
tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading