Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion exir/pass_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -191,6 +191,11 @@ def create_arg(self, a: Argument) -> torch.fx.Node:
if not hasattr(a, "constant") or a.constant is None:
raise ExportPassBaseError(f"Cannot add {a} to graph.")
a = a.constant
elif isinstance(a, torch.SymInt):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elif isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):

and add corresponding unit test for symfloat and symbool please

thank you for finding this bug btw @per

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to trigger the SymFloat and SymBool paths here, since they come from the dynamic shape export, which implies SymInts only, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, the proposed tests don't trigger the bug, since the symbool/symfloat aren't part of a list argument (as the shape argument is in view_copy). The tests you linked to are properly handled already (scalars and tensors).
The only alternative I've come up with is to manually construct a symbol with a shape_env, but that feels like a really constructed way to trigger the bug:

        shape_env = ShapeEnv()
        sym_bool = shape_env.create_unbacked_symbool()
        tracer_owner = ExportPass()
        tracer = tracer_owner.tracer
        tracer.create_arg([sym_bool])

I haven't been able to track down any operator that takes a list of bools or list of floats as an argument, so IMHO it makes sense to keep the test as is, since that is what is exercised through the normal export flow. Or am I missing something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, thank you for explaining and walking me through the code

if a.node.constant is not None:
return a.node.constant
else:
return a
node = super().create_arg(a)
if (
isinstance(a, torch.Tensor)
Expand Down
107 changes: 106 additions & 1 deletion exir/tests/test_dynamic_shape_propagation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from unittest import TestCase

import torch

from executorch import exir
from executorch.exir import to_edge
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
from executorch.exir.passes import (
DebugPass,
ExportPass,
HintBasedSymShapeEvalPass,
SpecPropPass,
)
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.tests.models import Repeat, TensorItem
from torch.export import export

Expand Down Expand Up @@ -67,3 +77,98 @@ def test_unbacked_symint(self):
self.assertEqual(
speclist[0].shape, [100, 100]
) # upper bound of TensorItem model


class TestSymIntViewArgs(TestCase):
class Conv1dToConv2d(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, input: torch.Tensor) -> torch.Tensor:
# Use view to make sure edge view handle symint shapes correctly.
# input = input.view(input.size(0), input.size(1), input.size(2), 1) # (N, C, H, W)
# weight = torch.randn(1, 16, 3, 1) # (out_channels, in_channels, kH, kW)
# return torch.nn.functional.conv2d(input, weight)

return torch.nn.functional.conv1d(
input, torch.randn(1, 16, 3)
) # (out_channels, in_channels, kW)

def get_random_inputs(self) -> tuple[torch.Tensor]:
return (torch.randn(1, 16, 50),) # (batch_size, channels, width)

def get_dynamic_shape(self) -> tuple[dict[int, torch.export.Dim]]:
dim = torch.export.Dim("width", min=10, max=100)
return ({2: dim},)

def test_symint_viewargs(self):
eager_model = TestSymIntViewArgs.Conv1dToConv2d()
inputs = eager_model.get_random_inputs()

class TestViewCopyPass(ExportPass):
def call_operator(self, op, args, kwargs, meta):
from executorch.exir.dialects._ops import ops as exir_ops

if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)

x = args[0]
x = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(x, list(x.data.shape) + [1]),
{},
meta,
)

w = args[1]
w = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(w, list(w.data.shape) + [1]),
{},
meta,
)

new_args = (
x,
w,
args[2],
args[3] + [1], # stride
args[4] + [0], # padding
args[5] + [1], # dilation
args[6],
args[7] + [0],
args[8],
)
x = super().call_operator(
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
)
x = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(x, list(x.data.shape)[:-1]),
{},
meta,
)

return x

prog = to_edge(
export(
eager_model,
inputs,
dynamic_shapes=eager_model.get_dynamic_shape(),
strict=True,
),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
new_prog = prog.transform(
[SpecPropPass(), ConstraintBasedSymShapeEvalPass(), TestViewCopyPass()]
)
gm = new_prog.exported_program().graph_module
DebugPass(show_spec=True)(gm)
*_, return_node = gm.graph.nodes
speclist = return_node.meta["spec"]

self.assertEqual(len(speclist), 1)
out_spec = speclist[0]
self.assertTrue(out_spec.is_upper_bound_tensor)
self.assertEqual(out_spec.shape, [1, 1, 98])
Loading