Output of the reproduction script:
Min ORT: [nan, nan, 4.0, nan]
Min TVM: [7.0, nan, 4.0, nan]
ArgMin ORT: 1
ArgMin TVM: 3
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
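def run_tvm(model, feeds):
    # Import the ONNX model into Relax, compile it for LLVM (CPU), and run "main".
    mod = from_onnx(model, keep_params_in_input=False)
    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(mod, target=tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    # Inputs are passed positionally, in the insertion order of the feeds dict.
    out = vm["main"](
        *[tvm.runtime.tensor(v, tvm.cpu()) for v in feeds.values()]
    )
    # A single-output model returns one tensor; otherwise take the first output.
    return (out[0] if isinstance(out, (list, tuple)) else out).numpy()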
# Case 1: element-wise Min with NaN in either operand.
node = helper.make_node("Min", ["a", "b"], ["y"])
graph = helper.make_graph(
[node],
"g",
[
helper.make_tensor_value_info("a", TensorProto.FLOAT, [4]),
helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
],
[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
)
model_min = helper.make_model(
graph,
opset_imports=[helper.make_opsetid("", 17)],
)
model_min.ir_version = 9
a = np.array([np.nan, 12.0, 4.0, np.nan], dtype=np.float32)
b = np.array([7.0, np.nan, 9.0, np.nan], dtype=np.float32)
ort_min = ort.InferenceSession(
model_min.SerializeToString(),
providers=["CPUExecutionProvider"],
).run(None, {"a": a, "b": b})[0]
tvm_min = run_tvm(model_min, {"a": a, "b": b})
print("Min ORT:", ort_min.tolist())
print("Min TVM:", tvm_min.tolist())
# Case 2: ArgMin over a vector that contains a NaN.
node = helper.make_node("ArgMin", ["x"], ["y"], axis=0, keepdims=0)
graph = helper.make_graph(
[node],
"g",
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])],
[helper.make_tensor_value_info("y", TensorProto.INT64, [])],
)
model_argmin = helper.make_model(
graph,
opset_imports=[helper.make_opsetid("", 17)],
)
model_argmin.ir_version = 9
x = np.array([8.0, np.nan, 3.0, 1.0, 5.0], dtype=np.float32)
ort_argmin = int(
ort.InferenceSession(
model_argmin.SerializeToString(),
providers=["CPUExecutionProvider"],
).run(None, {"x": x})[0]
)
tvm_argmin = int(run_tvm(model_argmin, {"x": x}))
print("ArgMin ORT:", ort_argmin)
print("ArgMin TVM:", tvm_argmin)
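Because NaN never compares equal to NaN, a plain `==` check would flag every NaN position as a mismatch. A NaN-aware comparison isolates the positions that actually differ. The sketch below uses the Min values printed by the script; `nan_mismatches` is a hypothetical helper, not part of ONNX Runtime or TVM.

```python
import numpy as np


def nan_mismatches(ref, got):
    """Hypothetical helper: indices where ref and got differ, with NaN == NaN treated as a match."""
    ref = np.asarray(ref, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    both_nan = np.isnan(ref) & np.isnan(got)
    # A position differs if the values are unequal and they are not both NaN.
    differ = (ref != got) & ~both_nan
    return np.flatnonzero(differ).tolist()


# Min outputs copied from the printed results (ONNX Runtime first, TVM second).
ort_min_out = [np.nan, np.nan, 4.0, np.nan]
tvm_min_out = [7.0, np.nan, 4.0, np.nan]
print(nan_mismatches(ort_min_out, tvm_min_out))  # [0]
```

When appended to the script, the same call can be made directly on `ort_min` and `tvm_min`. ArgMin returns an integer index, so a plain `==` is sufficient there.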
Expected behavior
TVM Relax should execute the ONNX Min and ArgMin operators consistently with ONNX Runtime when the inputs contain NaN.
For Min, ONNX Runtime propagates NaN when either input element is NaN, so TVM should do the same.
For ArgMin, TVM should return the same index as ONNX Runtime for inputs containing NaN.
Actual behavior
TVM Relax produces different results from ONNX Runtime:
For Min, the first element differs: ONNX Runtime returns NaN for min(NaN, 7.0), while TVM returns 7.0.
For ArgMin, ONNX Runtime selects index 1 for the input [8.0, NaN, 3.0, 1.0, 5.0], while TVM returns index 3.
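Both discrepancies fit two common NaN conventions, which can be reproduced with NumPy alone. This is an illustrative model, not a reading of either code base. It assumes that ONNX Runtime behaves like NaN-propagating `np.minimum` and `np.argmin`, and that TVM's lowering behaves like a plain `a < b` comparison. This fits the printed outputs but has not been confirmed in the source. `select_min` and `compare_argmin` are hypothetical helper names.

```python
import numpy as np

# Inputs copied from the reproduction script.
a = np.array([np.nan, 12.0, 4.0, np.nan], dtype=np.float32)
b = np.array([7.0, np.nan, 9.0, np.nan], dtype=np.float32)
x = np.array([8.0, np.nan, 3.0, 1.0, 5.0], dtype=np.float32)


def select_min(a, b):
    # Hypothetical comparison-based min: any comparison with NaN is False,
    # so the result depends on which operand holds the NaN.
    return np.where(a < b, a, b)


def compare_argmin(x):
    # Hypothetical comparison-based ArgMin: NaN never compares less than the
    # running minimum, so it is skipped unless it sits at index 0.
    best = 0
    for i in range(1, len(x)):
        if x[i] < x[best]:
            best = i
    return best


# NaN-propagating convention (matches the ONNX Runtime output).
print(np.minimum(a, b).tolist())  # [nan, nan, 4.0, nan]
print(int(np.argmin(x)))          # 1
# Comparison-based convention (matches the TVM output).
print(select_min(a, b).tolist())  # [7.0, nan, 4.0, nan]
print(compare_argmin(x))          # 3
```

Under this model, the comparison-based Min depends on operand order. For example, min(NaN, 7.0) yields 7.0 while min(12.0, NaN) yields NaN. This is the same pattern that appears in the TVM Min output.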
Environment
TVM: 0.14 environment / Relax ONNX frontend
ONNX Runtime: 1.23
Python: 3.11
Target: llvm
OS: Linux
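Steps to reproduce
Run the Python script above with ONNX Runtime and TVM installed. It builds single-node ONNX models (opset 17, IR version 9) for Min and ArgMin, runs them with ONNX Runtime (CPUExecutionProvider) and TVM Relax (llvm target) on the same NaN-containing inputs, and prints both results.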