Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,10 +1762,13 @@ def _impl_v1(cls, bb, inputs, attr, params):
strides = attr.get("strides", [1] * spatial_dims)
dilations = attr.get("dilations", [1] * spatial_dims)
output_padding = attr.get("output_padding", [0] * spatial_dims)
groups = attr.get("group", 1)
weight_shape = inputs[1].struct_info.shape
out_channels = weight_shape.values[1] * groups
if "kernel_shape" in attr:
kernel_shape = list(attr["kernel_shape"])
else:
kernel_shape = [int(s) for s in inputs[1].struct_info.shape.values[2:]]
kernel_shape = [int(s) for s in weight_shape.values[2:]]

# Resolve `auto_pad` per ONNX ConvTranspose spec. Unlike Conv, the spec
# derives `pads` from `output_shape`/`strides` when auto_pad is SAME_*,
Expand Down Expand Up @@ -1814,13 +1817,42 @@ def _impl_v1(cls, bb, inputs, attr, params):
padding=attr.get("pads", 0),
output_padding=output_padding,
dilation=dilations,
groups=attr.get("group", 1),
groups=groups,
data_layout=data_layout,
kernel_layout=kernel_layout,
)

if inputs[2] is not None:
bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
bias_shape = inputs[2].struct_info.shape
if hasattr(inputs[2].struct_info, "ndim"):
bias_ndim = inputs[2].struct_info.ndim
else:
bias_ndim = len(bias_shape)
Comment on lines +1827 to +1830
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using hasattr(inputs[2].struct_info, "ndim") to check for the presence of ndim is safe, but if inputs[2].struct_info does not have ndim and bias_shape is None (which can happen if the shape is completely dynamic/unknown), calling len(bias_shape) will raise a TypeError. We should add a fallback or check to ensure bias_shape is not None before calling len() on it.

Suggested change
if hasattr(inputs[2].struct_info, "ndim"):
bias_ndim = inputs[2].struct_info.ndim
else:
bias_ndim = len(bias_shape)
if hasattr(inputs[2].struct_info, "ndim") and inputs[2].struct_info.ndim is not None:
bias_ndim = inputs[2].struct_info.ndim
elif bias_shape is not None:
bias_ndim = len(bias_shape)
else:
bias_ndim = -1

if bias_ndim != 1:
raise ValueError(f"ConvTranspose bias must be a 1D tensor, but got ndim={bias_ndim}")

def _as_static_int(dim):
try:
return int(dim)
except (TypeError, ValueError, TVMError):
return None
Comment on lines +1834 to +1838
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _as_static_int helper catches TVMError, but in some environments or across FFI boundaries TVM might raise a standard RuntimeError or another custom exception. To make this helper more robust and prevent unexpected compilation failures, we should also catch RuntimeError.

Suggested change
def _as_static_int(dim):
try:
return int(dim)
except (TypeError, ValueError, TVMError):
return None
def _as_static_int(dim):
try:
return int(dim)
except (TypeError, ValueError, RuntimeError, TVMError):
return None


if isinstance(bias_shape, relax.ShapeExpr):
bias_channels = bias_shape.values[0]
static_bias_channels = _as_static_int(bias_channels)
static_out_channels = _as_static_int(out_channels)
if (
static_bias_channels is not None
and static_out_channels is not None
and static_bias_channels != static_out_channels
):
raise ValueError(
"ConvTranspose bias length must equal output channels "
f"(weight.shape[1] * group = {static_out_channels}), "
f"but got {static_bias_channels}."
)

bias = relax.op.reshape(inputs[2], [1, out_channels] + [1] * (ndim - 2))
conv_out = relax.op.add(conv_out, bias)

return conv_out
Expand Down
42 changes: 33 additions & 9 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,9 +1855,9 @@ def _verify_conv(input_shape, weight_shape):
@pytest.mark.parametrize("pad", [0, 2])
@pytest.mark.parametrize("output_pad", [0, 1])
def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool, output_pad: int):
def _verify_conv_transpose(input_shape, weight_shape):
def _verify_conv_transpose(input_shape, weight_shape, group=1):
nd = len(weight_shape) - 2
output_shape = [input_shape[0], weight_shape[0]] + [
output_shape = [input_shape[0], weight_shape[1] * group] + [
(input_shape[i] - 1) * stride
- 2 * pad
+ dilation * (weight_shape[i] - 1)
Expand All @@ -1874,7 +1874,7 @@ def _verify_conv_transpose(input_shape, weight_shape):
dilations=[dilation] * nd,
pads=[pad] * nd * 2,
output_padding=[output_pad] * nd,
group=input_shape[1] // weight_shape[1],
group=group,
)
graph = helper.make_graph(
[conv_node],
Expand All @@ -1891,14 +1891,38 @@ def _verify_conv_transpose(input_shape, weight_shape):
check_correctness(model, atol=1e-4)

# ConvTranspose1D
_verify_conv_transpose([3, 4, 32], [4, 4, 3])
_verify_conv_transpose([3, 4, 32], [4, 2, 3]) # group=2
_verify_conv_transpose([3, 4, 32], [4, 6, 3])
_verify_conv_transpose([3, 4, 32], [4, 3, 3], group=2)
# ConvTranspose2D
_verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3])
_verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2
_verify_conv_transpose([3, 4, 32, 32], [4, 6, 3, 3])
_verify_conv_transpose([3, 4, 32, 32], [4, 3, 3, 3], group=2)
# ConvTranspose3D
_verify_conv_transpose([3, 4, 12, 12, 12], [4, 4, 3, 3, 3])
_verify_conv_transpose([3, 4, 12, 12, 12], [4, 2, 3, 3, 3]) # group=2
_verify_conv_transpose([3, 4, 12, 12, 12], [4, 6, 3, 3, 3])
_verify_conv_transpose([3, 4, 12, 12, 12], [4, 3, 3, 3, 3], group=2)


def test_conv_transpose_invalid_bias_channel_count():
    """Check that importing a ConvTranspose with a mismatched bias length fails.

    The model declares a weight of shape [4, 3, 3, 3] with ``group=2`` and a
    bias of shape [5]. The test expects ``from_onnx`` to raise a ``ValueError``
    whose message matches "ConvTranspose bias length".
    """
    test_name = "conv_transpose_invalid_bias_test"

    # Declare every tensor as a graph input; `keep_params_in_input=True` below
    # is passed through to the importer unchanged.
    graph_inputs = [
        helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, tensor_shape)
        for tensor_name, tensor_shape in (
            ("x", [1, 4, 5, 5]),
            ("w", [4, 3, 3, 3]),
            ("b", [5]),
        )
    ]
    graph_outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 6, 7, 7])]

    node = helper.make_node(
        "ConvTranspose",
        inputs=["x", "w", "b"],
        outputs=["y"],
        pads=[0, 0, 0, 0],
        group=2,
    )
    graph = helper.make_graph(
        [node],
        test_name,
        inputs=graph_inputs,
        outputs=graph_outputs,
    )
    model = helper.make_model(graph, producer_name=test_name)

    with pytest.raises(ValueError, match="ConvTranspose bias length"):
        from_onnx(model, opset=14, keep_params_in_input=True)


@pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"])
Expand Down