Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,15 +925,35 @@ def convert_range(self, op):

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]

expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]
def get_scalar_value(tensor):
    """Extract a Python scalar from a TFLite RANGE operand tensor.

    The operand is resolved either from an already-registered Relax
    expression (constants only) or from the flatbuffer tensor data.

    Raises:
        tvm.error.OpNotImplemented: if the operand maps to a non-constant
            Relax expression (i.e. a dynamic scalar input).
    """
    if self.has_expr(tensor.tensor_idx):
        # Tensor already converted to a Relax expr; only a Constant
        # carries a concrete value we can read back as numpy data.
        expr = self.get_expr(tensor.tensor_idx)
        if isinstance(expr, relax.Constant):
            value = expr.data.numpy()
        else:
            # relax.op.arange currently expects scalar-like values here.
            # Keep dynamic scalar RANGE explicit until frontend support is added.
            raise tvm.error.OpNotImplemented(
                "TFLite RANGE with dynamic scalar inputs is not supported in Relax frontend yet."
            )
    else:
        # No expr registered yet — presumably reads the constant buffer
        # from the flatbuffer; confirm against get_tensor_value's contract.
        value = self.get_tensor_value(tensor)

    # TFLite RANGE operands are scalar tensors in the flatbuffer.
    assert value.size == 1, "RANGE scalar input must have exactly one element"
    return value.item()

start_value = get_scalar_value(start)
limit_value = get_scalar_value(limit)
delta_value = get_scalar_value(delta)

# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

out = relax.op.arange(expressions[0], expressions[1], expressions[2], out_type)
out = relax.op.arange(start_value, limit_value, delta_value, out_type)

return out

Expand All @@ -942,6 +962,7 @@ def convert_shape(self, op):

from tflite.BuiltinOptions import BuiltinOptions
from tflite.ShapeOptions import ShapeOptions
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
Expand All @@ -951,7 +972,10 @@ def convert_shape(self, op):
shape_options = ShapeOptions()
shape_options.Init(op_options.Bytes, op_options.Pos)

out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))
# SHAPE must materialize as a tensor output in Relax, not just symbolic shape info.
out = relax.op.shape_to_tensor(relax.op.shape_of(self.get_tensor_expr(input_tensors[0])))
if shape_options.OutType() == TensorType.INT32:
out = relax.op.astype(out, "int32")

return out

Expand Down Expand Up @@ -4055,7 +4079,7 @@ def _input_type(model):
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
assert inputs_count >= 1
# TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
Expand Down Expand Up @@ -4209,7 +4233,9 @@ def func(self, data):
op_converter.convert_op_to_relax()

# params and outputs
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
# Resolve outputs through tensor wrappers so constant/prefetched outputs are handled.
output_tensors = op_converter.get_tensors(model_outputs)
outputs = [op_converter.get_tensor_expr(tensor) for tensor in output_tensors]
outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs)
output_var = bb.emit_output(outputs)

Expand Down
72 changes: 72 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,78 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 2, 15), dtype="f
verify(Reshape, Expected)


@pytest.mark.parametrize(
    "input_shape, out_type",
    [
        ((2, 3, 4), tf.int32),
        ((5,), tf.int64),
        ((1, 1, 1, 1), tf.int32),
        ((), tf.int32),
        ((0, 3), tf.int64),
    ],
)
def test_shape(input_shape, out_type):
    """Check SHAPE lowering over ranks 0-4, a zero-sized dim, and both int out types."""

    input_spec = tf.TensorSpec(shape=input_shape, dtype=tf.float32)

    class _ShapeModule(tf.Module):
        @tf.function(input_signature=[input_spec])
        def func(self, x):
            return tf.shape(x, out_type=out_type)

    verify(_ShapeModule)


def test_shape_dynamic_dim():
    """Check SHAPE lowering when the leading input dimension is unknown at trace time."""

    dynamic_spec = tf.TensorSpec(shape=(None, 3), dtype=tf.float32)

    class _DynamicShapeModule(tf.Module):
        @tf.function(input_signature=[dynamic_spec])
        def func(self, x):
            return tf.shape(x, out_type=tf.int32)

    verify(_DynamicShapeModule)


@pytest.mark.parametrize(
    "start, limit, delta, dtype",
    [
        (0, 8, 2, tf.int32),
        (1, 9, 2, tf.int64),
        (0.0, 1.0, 0.2, tf.float32),
        (8, 0, -2, tf.int32),
        (0, 0, 1, tf.int32),
        (0, 7, 2, tf.int32),
        (0.0, -1.0, -0.25, tf.float32),
    ],
)
def test_range(start, limit, delta, dtype):
    """Check RANGE lowering for constant scalar bounds: int/float dtypes, negative
    steps, empty and non-divisible ranges."""

    class _RangeModule(tf.Module):
        @tf.function(input_signature=[])
        def func(self):
            return tf.range(start, limit, delta, dtype=dtype)

    verify(_RangeModule)


def test_range_dynamic_scalar_inputs_not_supported():
    """Non-constant (runtime) scalar bounds for RANGE must raise OpNotImplemented."""

    scalar_i32 = tf.TensorSpec(shape=(), dtype=tf.int32)

    class _DynamicRangeModule(tf.Module):
        @tf.function(input_signature=[scalar_i32, scalar_i32, scalar_i32])
        def func(self, start, limit, delta):
            return tf.range(start, limit, delta, dtype=tf.int32)

    with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"):
        verify(_DynamicRangeModule)

def test_tile_ir():
"""TILE conversion with explicit Relax IR structural check."""

Expand Down
Loading