From cbc7f0b1d82167562eff122802504eb5b92ffde4 Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:56:25 -0700 Subject: [PATCH] fix: ONNX ConvTranspose bias broadcast mismatch in Relax frontend --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 38 +++++++++++++++-- tests/python/relax/test_frontend_onnx.py | 42 +++++++++++++++---- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b82fceff1d6c..d3c6b95ae093 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1762,10 +1762,13 @@ def _impl_v1(cls, bb, inputs, attr, params): strides = attr.get("strides", [1] * spatial_dims) dilations = attr.get("dilations", [1] * spatial_dims) output_padding = attr.get("output_padding", [0] * spatial_dims) + groups = attr.get("group", 1) + weight_shape = inputs[1].struct_info.shape + out_channels = weight_shape.values[1] * groups if "kernel_shape" in attr: kernel_shape = list(attr["kernel_shape"]) else: - kernel_shape = [int(s) for s in inputs[1].struct_info.shape.values[2:]] + kernel_shape = [int(s) for s in weight_shape.values[2:]] # Resolve `auto_pad` per ONNX ConvTranspose spec. Unlike Conv, the spec # derives `pads` from `output_shape`/`strides` when auto_pad is SAME_*, @@ -1814,13 +1817,42 @@ def _impl_v1(cls, bb, inputs, attr, params): padding=attr.get("pads", 0), output_padding=output_padding, dilation=dilations, - groups=attr.get("group", 1), + groups=groups, data_layout=data_layout, kernel_layout=kernel_layout, ) if inputs[2] is not None: - bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2)) + bias_shape = inputs[2].struct_info.shape + if hasattr(inputs[2].struct_info, "ndim"): + bias_ndim = inputs[2].struct_info.ndim + else: + bias_ndim = len(bias_shape) + if bias_ndim != 1: + raise ValueError(f"ConvTranspose bias must be a 1D tensor, but got ndim={bias_ndim}") + + def _as_static_int(dim): + try: + return int(dim) + except (TypeError, ValueError, TVMError): + return None + + if isinstance(bias_shape, relax.ShapeExpr): + bias_channels = bias_shape.values[0] + static_bias_channels = _as_static_int(bias_channels) + static_out_channels = _as_static_int(out_channels) + if ( + static_bias_channels is not None + and static_out_channels is not None + and static_bias_channels != static_out_channels + ): + raise ValueError( + "ConvTranspose bias length must equal output channels " + f"(weight.shape[1] * group = {static_out_channels}), " + f"but got {static_bias_channels}." + ) + + bias = relax.op.reshape(inputs[2], [1, out_channels] + [1] * (ndim - 2)) conv_out = relax.op.add(conv_out, bias) return conv_out diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7ee10993a4e9..aafcfd4e9237 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1855,9 +1855,9 @@ def _verify_conv(input_shape, weight_shape): @pytest.mark.parametrize("pad", [0, 2]) @pytest.mark.parametrize("output_pad", [0, 1]) def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool, output_pad: int): - def _verify_conv_transpose(input_shape, weight_shape): + def _verify_conv_transpose(input_shape, weight_shape, group=1): nd = len(weight_shape) - 2 - output_shape = [input_shape[0], weight_shape[0]] + [ + output_shape = [input_shape[0], weight_shape[1] * group] + [ (input_shape[i] - 1) * stride - 2 * pad + dilation * (weight_shape[i] - 1) @@ -1874,7 +1874,7 @@ def _verify_conv_transpose(input_shape, weight_shape): dilations=[dilation] * nd, pads=[pad] * nd * 2, output_padding=[output_pad] * nd, - group=input_shape[1] // weight_shape[1], + group=group, ) graph = helper.make_graph( [conv_node], @@ -1891,14 +1891,38 @@ def _verify_conv_transpose(input_shape, weight_shape): check_correctness(model, atol=1e-4) # ConvTranspose1D - _verify_conv_transpose([3, 4, 32], [4, 4, 3]) - _verify_conv_transpose([3, 4, 32], [4, 2, 3]) # group=2 + _verify_conv_transpose([3, 4, 32], [4, 6, 3]) + _verify_conv_transpose([3, 4, 32], [4, 3, 3], group=2) # ConvTranspose2D - _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3]) - _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2 + _verify_conv_transpose([3, 4, 32, 32], [4, 6, 3, 3]) + _verify_conv_transpose([3, 4, 32, 32], [4, 3, 3, 3], group=2) # ConvTranspose3D - _verify_conv_transpose([3, 4, 12, 12, 12], [4, 4, 3, 3, 3]) - _verify_conv_transpose([3, 4, 12, 12, 12], [4, 2, 3, 3, 3]) # group=2 + _verify_conv_transpose([3, 4, 12, 12, 12], [4, 6, 3, 3, 3]) + _verify_conv_transpose([3, 4, 12, 12, 12], [4, 3, 3, 3, 3], group=2) + + +def test_conv_transpose_invalid_bias_channel_count(): + conv_node = helper.make_node( + "ConvTranspose", + inputs=["x", "w", "b"], + outputs=["y"], + pads=[0, 0, 0, 0], + group=2, + ) + graph = helper.make_graph( + [conv_node], + "conv_transpose_invalid_bias_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4, 5, 5]), + helper.make_tensor_value_info("w", TensorProto.FLOAT, [4, 3, 3, 3]), + helper.make_tensor_value_info("b", TensorProto.FLOAT, [5]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 6, 7, 7])], + ) + + model = helper.make_model(graph, producer_name="conv_transpose_invalid_bias_test") + with pytest.raises(ValueError, match="ConvTranspose bias length"): + from_onnx(model, opset=14, keep_params_in_input=True) @pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"])