From 51909dcc2bf83b510fe38bd74c1986cf1263c078 Mon Sep 17 00:00:00 2001 From: rknastenka Date: Mon, 13 Apr 2026 13:44:08 +0300 Subject: [PATCH 1/3] [Frontend][TFLite] Add test coverage for SHAPE and RANGE operators --- .../relax/frontend/tflite/tflite_frontend.py | 42 +++++++++-- tests/python/relax/test_frontend_tflite.py | 73 +++++++++++++++++++ 2 files changed, 109 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 9c99e98e01c8..16e3e8a07349 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -925,15 +925,37 @@ def convert_range(self, op): start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] - expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]] - + def get_scalar_value(tensor): + if self.has_expr(tensor.tensor_idx): + expr = self.get_expr(tensor.tensor_idx) + if isinstance(expr, relax.Constant): + value = expr.data.numpy() + else: + # relax.op.arange currently expects scalar-like values here. + # Keep dynamic scalar RANGE explicit until frontend support is added. + raise tvm.error.OpNotImplemented( + "TFLite RANGE with dynamic scalar inputs is not supported in Relax frontend yet." + ) + else: + value = self.get_tensor_value(tensor) + + if isinstance(value, np.ndarray): + # TFLite RANGE operands are scalar tensors in the flatbuffer. + assert value.size == 1, "RANGE scalar input must have exactly one element" + return value.item() + return value + + start_value = get_scalar_value(start) + limit_value = get_scalar_value(limit) + delta_value = get_scalar_value(delta) + # out type inference if delta.tensor.Type() == TensorType.FLOAT32: out_type = self.get_tensor_type_str(delta.tensor.Type()) else: out_type = self.get_tensor_type_str(start.tensor.Type()) - out = relax.op.arange(expressions[0], expressions[1], expressions[2], out_type) + out = relax.op.arange(start_value, limit_value, delta_value, out_type) return out @@ -942,6 +964,7 @@ def convert_shape(self, op): from tflite.BuiltinOptions import BuiltinOptions from tflite.ShapeOptions import ShapeOptions + from tflite.TensorType import TensorType input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -951,7 +974,12 @@ def convert_shape(self, op): shape_options = ShapeOptions() shape_options.Init(op_options.Bytes, op_options.Pos) - out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0])) + # SHAPE must materialize as a tensor output in Relax, not just symbolic shape info. + out = relax.op.shape_to_tensor(relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))) + if shape_options.OutType() == TensorType.INT32: + out = relax.op.astype(out, "int32") + elif shape_options.OutType() == TensorType.INT64: + out = relax.op.astype(out, "int64") return out @@ -4067,7 +4095,7 @@ def _input_type(model): for subgraph_index in range(subgraph_count): subgraph = model.Subgraphs(subgraph_index) inputs_count = subgraph.InputsLength() - assert inputs_count >= 1 + # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models). for input_index in range(inputs_count): input_ = subgraph.Inputs(input_index) assert subgraph.TensorsLength() > input_ @@ -4221,7 +4249,9 @@ def func(self, data): op_converter.convert_op_to_relax() # params and outputs - outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] + # Resolve outputs through tensor wrappers so constant/prefetched outputs are handled. + output_tensors = op_converter.get_tensors(model_outputs) + outputs = [op_converter.get_tensor_expr(tensor) for tensor in output_tensors] outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs) output_var = bb.emit_output(outputs) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 58af46cbc949..baeb3108d317 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -279,6 +279,79 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 2, 15), dtype="f verify(Reshape, Expected) +@pytest.mark.parametrize( + "input_shape, out_type", + [ + ((2, 3, 4), tf.int32), + ((5,), tf.int64), + ((1, 1, 1, 1), tf.int32), + ((), tf.int32), + ((0, 3), tf.int64), + ], +) +def test_shape(input_shape, out_type): + """SHAPE conversion for static-rank non-quantized tensors.""" + + class Shape(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) + def func(self, x): + return tf.shape(x, out_type=out_type) + + verify(Shape) + + +def test_shape_dynamic_dim(): + """SHAPE conversion with a dynamic input dimension.""" + + class ShapeDynamic(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)]) + def func(self, x): + return tf.shape(x, out_type=tf.int32) + + verify(ShapeDynamic) + + +@pytest.mark.parametrize( + "start, limit, delta, dtype", + [ + (0, 8, 2, tf.int32), + (1, 9, 2, tf.int64), + (0.0, 1.0, 0.2, tf.float32), + (8, 0, -2, tf.int32), + (0, 0, 1, tf.int32), + (0, 7, 2, tf.int32), + (0.0, -1.0, -0.25, tf.float32), + ], +) +def test_range(start, limit, delta, dtype): + """RANGE conversion with non-quantized constant scalar bounds.""" + + class Range(tf.Module): + @tf.function(input_signature=[]) + def func(self): + return tf.range(start, limit, delta, dtype=dtype) + + verify(Range) + + +def test_range_dynamic_scalar_inputs_not_supported(): + """RANGE conversion currently rejects dynamic scalar inputs.""" + + class RangeDynamic(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.int32), + ] + ) + def func(self, start, limit, delta): + return tf.range(start, limit, delta, dtype=tf.int32) + + with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"): + verify(RangeDynamic) + + def test_concat_v2(): class ConcatV2(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) From c0c835d657807b859ea2d7caeeff040382f6da78 Mon Sep 17 00:00:00 2001 From: Bana Date: Mon, 13 Apr 2026 14:02:55 +0300 Subject: [PATCH 2/3] Update python/tvm/relax/frontend/tflite/tflite_frontend.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 16e3e8a07349..f2b1bdfc58f6 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -978,8 +978,6 @@ def convert_shape(self, op): out = relax.op.shape_to_tensor(relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))) if shape_options.OutType() == TensorType.INT32: out = relax.op.astype(out, "int32") - elif shape_options.OutType() == TensorType.INT64: - out = relax.op.astype(out, "int64") return out From 6aa4fe56897fdf5de19e01cc0ebcb989dab4e7d6 Mon Sep 17 00:00:00 2001 From: rknastenka Date: Mon, 13 Apr 2026 14:04:53 +0300 Subject: [PATCH 3/3] resolves gemini suggestion --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 16e3e8a07349..79a9b41027d5 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -939,11 +939,9 @@ def get_scalar_value(tensor): else: value = self.get_tensor_value(tensor) - if isinstance(value, np.ndarray): - # TFLite RANGE operands are scalar tensors in the flatbuffer. - assert value.size == 1, "RANGE scalar input must have exactly one element" - return value.item() - return value + # TFLite RANGE operands are scalar tensors in the flatbuffer. + assert value.size == 1, "RANGE scalar input must have exactly one element" + return value.item() start_value = get_scalar_value(start) limit_value = get_scalar_value(limit)