
[Bug] Relax ONNX Min and ArgMin mishandle NaN semantics #19540

@ALinrunrun

Description

Expected behavior

TVM Relax should execute ONNX Min and ArgMin consistently with ONNX Runtime when inputs contain NaN.

For Min, ONNX Runtime propagates NaN when either input element is NaN.

For ArgMin, TVM should return the same index as ONNX Runtime for inputs containing NaN.

Actual behavior

TVM Relax produces different results from ONNX Runtime:

Min ORT: [nan, nan, 4.0, nan]
Min TVM: [7.0, nan, 4.0, nan]

ArgMin ORT: 1
ArgMin TVM: 3
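A plain `==` is False wherever either side is NaN, so it cannot tell whether NaNs appear in the same positions. A NaN-aware comparison isolates the real mismatch. The sketch below uses NumPy and hard-codes the two Min outputs reported above. It is an illustration for checking results, not part of the original reproducer.

```python
import numpy as np

# Min outputs as reported above (hard-coded for illustration).
ort_min = np.array([np.nan, np.nan, 4.0, np.nan], dtype=np.float32)
tvm_min = np.array([7.0, np.nan, 4.0, np.nan], dtype=np.float32)

# Plain equality: False wherever either element is NaN.
naive_match = ort_min == tvm_min

# NaN-aware equality: NaN on both sides at the same position counts as a match.
nan_aware_match = (ort_min == tvm_min) | (np.isnan(ort_min) & np.isnan(tvm_min))

# Indices where the two backends genuinely disagree.
mismatch = np.flatnonzero(~nan_aware_match)
print("mismatching indices:", mismatch.tolist())
```

`np.testing.assert_array_equal` applies the same NaN-in-same-position rule, so it can serve as a one-line check in a regression test.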

For Min, only the first element differs: ONNX Runtime returns NaN for min(NaN, 7.0), while TVM returns 7.0.
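The NumPy sketch below shows where a value like 7.0 can come from. `np.minimum` propagates NaN and reproduces the ONNX Runtime output for these inputs. A comparison-based select of the form `where(a < b, a, b)` reproduces the TVM output, because every comparison involving NaN is False, so the second operand wins whenever the first is NaN. That TVM lowers Min this way is an assumption, not a confirmed reading of TVM's code. `np.fmin`, which ignores NaN, is included for contrast.

```python
import numpy as np

a = np.array([np.nan, 12.0, 4.0, np.nan], dtype=np.float32)
b = np.array([7.0, np.nan, 9.0, np.nan], dtype=np.float32)

# NaN-propagating minimum: NaN if either element is NaN (matches ORT here).
propagating = np.minimum(a, b)

# Assumed comparison-based formulation (not taken from TVM source):
# NaN < x is False, so b is chosen whenever a is NaN (matches TVM here).
select_based = np.where(a < b, a, b)

# NaN-ignoring minimum, for contrast: returns the non-NaN operand when one exists.
ignoring = np.fmin(a, b)

print("np.minimum:   ", propagating.tolist())
print("where(a < b): ", select_based.tolist())
print("np.fmin:      ", ignoring.tolist())
```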

For ArgMin, on the input [8.0, NaN, 3.0, 1.0, 5.0], ONNX Runtime returns index 1 (the position of the NaN), while TVM returns index 3 (the position of 1.0, the smallest non-NaN value).
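The ArgMin difference can be sketched the same way. NumPy's `np.argmin` returns the index of the first NaN, which is 1 here, the same as ONNX Runtime. A running-minimum scan that replaces the current best only on a strict `<` never selects a NaN after position 0 and returns 3, like TVM. Both helper functions below are hypothetical illustrations, not TVM or ONNX Runtime code. The idea that TVM's reduction behaves like the strict-`<` scan is an assumption.

```python
import math

import numpy as np

x = np.array([8.0, np.nan, 3.0, 1.0, 5.0], dtype=np.float32)


def argmin_strict_less(values):
    """Hypothetical helper, assumed to mirror the observed TVM result.

    Replaces the running best only when values[i] < best. Any comparison
    involving NaN is False, so a NaN is never chosen unless it is at index 0.
    """
    best_idx = 0
    for i in range(1, len(values)):
        if values[i] < values[best_idx]:
            best_idx = i
    return best_idx


def argmin_nan_first(values):
    """Hypothetical helper: treat NaN as smaller than every number."""
    for i, v in enumerate(values):
        if math.isnan(v):
            return i  # first NaN wins
    return argmin_strict_less(values)


print("np.argmin:         ", int(np.argmin(x)))     # first NaN index
print("np.nanargmin:      ", int(np.nanargmin(x)))  # skips NaN
print("strict-< scan:     ", argmin_strict_less(x))
print("NaN-first scan:    ", argmin_nan_first(x))
```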

Environment

TVM: 0.14 environment / Relax ONNX frontend
ONNX Runtime: 1.23
Python: 3.11
Target: llvm
OS: Linux

Steps to reproduce

```python
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx


def run_tvm(model, feeds):
    # Import the ONNX model into Relax and compile it for the CPU.
    mod = from_onnx(model, keep_params_in_input=False)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(mod, target=tvm.target.Target("llvm"))

    vm = relax.VirtualMachine(ex, tvm.cpu())

    # Inputs are passed positionally, in the insertion order of the feeds dict.
    out = vm["main"](
        *[tvm.runtime.tensor(v, tvm.cpu()) for v in feeds.values()]
    )

    return (out[0] if isinstance(out, (list, tuple)) else out).numpy()


# --- Min: element-wise minimum of two tensors that contain NaN ---
node = helper.make_node("Min", ["a", "b"], ["y"])

graph = helper.make_graph(
    [node],
    "g",
    [
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
)

model_min = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17)],
)
model_min.ir_version = 9

a = np.array([np.nan, 12.0, 4.0, np.nan], dtype=np.float32)
b = np.array([7.0, np.nan, 9.0, np.nan], dtype=np.float32)

ort_min = ort.InferenceSession(
    model_min.SerializeToString(),
    providers=["CPUExecutionProvider"],
).run(None, {"a": a, "b": b})[0]

tvm_min = run_tvm(model_min, {"a": a, "b": b})

print("Min ORT:", ort_min.tolist())
print("Min TVM:", tvm_min.tolist())


# --- ArgMin: index of the minimum along axis 0, with a NaN at index 1 ---
node = helper.make_node("ArgMin", ["x"], ["y"], axis=0, keepdims=0)

graph = helper.make_graph(
    [node],
    "g",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])],
    [helper.make_tensor_value_info("y", TensorProto.INT64, [])],
)

model_argmin = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17)],
)
model_argmin.ir_version = 9

x = np.array([8.0, np.nan, 3.0, 1.0, 5.0], dtype=np.float32)

ort_argmin = int(
    ort.InferenceSession(
        model_argmin.SerializeToString(),
        providers=["CPUExecutionProvider"],
    ).run(None, {"x": x})[0]
)

tvm_argmin = int(run_tvm(model_argmin, {"x": x}))

print("ArgMin ORT:", ort_argmin)
print("ArgMin TVM:", tvm_argmin)
```

Labels

  • needs-triage
  • type: bug