diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b82fceff1d6c..a92e530af6cc 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4203,6 +4203,17 @@ def _argreduce_select_last_index(bb, data, axis, keepdims, op): return relax.op.subtract(offset, flipped_idx) +def _argreduce_sanitize_nan(bb, data, *, for_min): + """Match ONNX Runtime ArgMax/ArgMin behavior by making NaN win comparisons.""" + dtype = data.struct_info.dtype + if not _relax_dtype_is_floating_point(dtype): + return data + replacement = -_np.inf if for_min else _np.inf + return bb.emit( + relax.op.where(relax.op.isnan(data), relax.const(replacement, dtype), data) + ) + + class ArgMax(OnnxOpConverter): """Converts an onnx ArgMax node into an equivalent Relax expression.""" @@ -4218,19 +4229,19 @@ def _check_attrs(cls, data, attr, shift_axis=True): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=False) axis, keepdims = cls._check_attrs(data, attr, False) return relax.op.argmax(data, axis, keepdims) @classmethod def _impl_v11(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=False) axis, keepdims = cls._check_attrs(data, attr) return relax.op.argmax(data, axis, keepdims) @classmethod def _impl_v12(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=False) axis, keepdims = cls._check_attrs(data, attr) select_last_index = attr.get("select_last_index", False) if select_last_index: @@ -4253,19 +4264,19 @@ def _check_attrs(cls, data, attr, shift_axis=True): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=True) axis, keepdims = cls._check_attrs(data, attr, False) return relax.op.argmin(data, axis, keepdims) @classmethod def _impl_v11(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=True) axis, keepdims = cls._check_attrs(data, attr) return relax.op.argmin(data, axis, keepdims) @classmethod def _impl_v12(cls, bb, inputs, attr, params): - data = inputs[0] + data = _argreduce_sanitize_nan(bb, inputs[0], for_min=True) axis, keepdims = cls._check_attrs(data, attr) select_last_index = attr.get("select_last_index", False) if select_last_index: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7ee10993a4e9..47c12e8b7c44 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2860,6 +2860,87 @@ def verify_arg_min_max(input_dim, in_dtype, op_name="ArgMax", axis=None, keepdim verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims) +def _make_arg_min_max_model( + op_name, data_shape, out_shape, axis, keepdims, select_last_index=0 +): + node = helper.make_node( + op_name, + inputs=["data"], + outputs=["out"], + axis=axis, + keepdims=keepdims, + select_last_index=select_last_index, + ) + graph = helper.make_graph( + [node], + "arg_min_max_nan_test", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, list(data_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, out_shape)], + ) + return helper.make_model(graph, producer_name="arg_min_max_nan_test") + + +@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"]) +def test_arg_min_max_nan_matches_ort_indices(op_name): + data = np.array( + [ + [2.0, np.nan, 7.0, 4.0, 1.0], + [np.nan, 2.0, 7.0, 4.0, 1.0], + [2.0, 4.0, 7.0, 1.0, np.nan], + ], + dtype=np.float32, + ) + expected = np.array([1, 0, 4], dtype=np.int64) + numpy_result = np.argmax(data, axis=1) if op_name == "ArgMax" else np.argmin(data, axis=1) + np.testing.assert_array_equal(numpy_result, expected) + + model = _make_arg_min_max_model(op_name, data.shape, [3], axis=1, keepdims=0) + check_correctness(model, inputs={"data": data}, opset=12) + + +@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"]) +def test_arg_min_max_nan_keepdims_and_all_nan(op_name): + data = np.array( + [ + [[np.nan, np.nan, np.nan], [5.0, np.nan, 1.0]], + [[2.0, 3.0, np.nan], [np.nan, -1.0, -2.0]], + ], + dtype=np.float32, + ) + model = _make_arg_min_max_model(op_name, data.shape, [2, 2, 1], axis=2, keepdims=1) + check_correctness(model, inputs={"data": data}, opset=12) + + +@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"]) +def test_arg_min_max_nan_select_last_index(op_name): + data = np.array( + [ + [[np.nan, 2.0, np.nan, 1.0], [np.nan, np.nan, np.nan, np.nan]], + [[5.0, np.nan, 1.0, np.nan], [4.0, 3.0, 2.0, 1.0]], + ], + dtype=np.float32, + ) + model = _make_arg_min_max_model( + op_name, + data.shape, + [2, 2], + axis=2, + keepdims=0, + select_last_index=1, + ) + check_correctness(model, inputs={"data": data}, opset=12) + + +@pytest.mark.parametrize("op_name", ["ArgMax", "ArgMin"]) +def test_arg_min_max_finite_regression(op_name): + data = np.array( + [[2.0, 4.0, 7.0, 1.0, 5.0], [3.0, -2.0, 8.0, 6.0, 0.0]], + dtype=np.float32, + ) + model = _make_arg_min_max_model(op_name, data.shape, [2], axis=1, keepdims=0) + check_correctness(model, inputs={"data": data}, opset=12) + + @pytest.mark.parametrize("axis", [-1, 0, 1]) @pytest.mark.parametrize("largest", [True, False]) def test_topk(axis: int, largest: int):