diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index f71e5c564cea..b344d9361a7a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -2832,9 +2832,7 @@ def convert_batch_matmul(self, op): new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b] max_rank = max(rank_a, rank_b) - batch_shape = [ - max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2) - ] + batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)] a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])] b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])] @@ -2903,7 +2901,14 @@ def convert_depth_to_space(self, op): depth_to_space_options = DepthToSpaceOptions() depth_to_space_options.Init(op_options.Bytes, op_options.Pos) block_size = depth_to_space_options.BlockSize() - out = relax.op.nn.depth_to_space(in_expr, block_size, layout="NHWC") + + # TFLite uses NHWC layout: (N, H, W, C) -> (N, H*bs, W*bs, C/(bs*bs)) + input_shape = self.get_tensor_shape(input_tensor) + n, h, w, c = input_shape + out_c = c // (block_size**2) + out = relax.op.reshape(in_expr, (n, h, w, block_size, block_size, out_c)) + out = relax.op.permute_dims(out, [0, 1, 3, 2, 4, 5]) + out = relax.op.reshape(out, (n, h * block_size, w * block_size, out_c)) return out @@ -2924,7 +2929,17 @@ def convert_space_to_depth(self, op): space_to_depth_options = SpaceToDepthOptions() space_to_depth_options.Init(op_options.Bytes, op_options.Pos) block_size = space_to_depth_options.BlockSize() - out = relax.op.nn.space_to_depth(in_expr, block_size, layout="NHWC") + + # TFLite uses NHWC layout: (N, H, W, C) -> (N, H/bs, W/bs, C*bs*bs) + input_shape = self.get_tensor_shape(input_tensor) + n, h, w, c = input_shape + out = relax.op.reshape( + in_expr, (n, h // block_size, block_size, w // block_size, block_size, c) + ) + out = relax.op.permute_dims(out, [0, 1, 3, 2, 4, 5]) + out = relax.op.reshape( + out, (n, h // block_size, w // block_size, c * block_size * block_size) + ) return out @@ -3348,8 +3363,8 @@ def convert_nms_v5(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 6, "input tensor length should be 6" - boxes = self.get_tensor_expr(input_tensors[0]) - scores = self.get_tensor_expr(input_tensors[1]) + boxes = self.get_tensor_expr(input_tensors[0]) + scores = self.get_tensor_expr(input_tensors[1]) max_output_size = self.get_tensor_value(input_tensors[2]) iou_threshold = self.get_tensor_value(input_tensors[3]) @@ -3403,14 +3418,16 @@ def convert_nms_v5(self, op): ) selected_indices = relax.op.squeeze(nms_ret[0], axis=[0]) - selected_indices = relax.op.strided_slice(selected_indices, axes=[0], begin=[0], end=[max_output_size]) - num_valid = relax.op.reshape(nms_ret[1], []) + selected_indices = relax.op.strided_slice( + selected_indices, axes=[0], begin=[0], end=[max_output_size] + ) + num_valid = relax.op.reshape(nms_ret[1], []) # Clamp out-of-bound padded indices to prevent take() crash. num_boxes = int(self.get_tensor_shape(input_tensors[0])[0]) safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1) selected_scores = relax.op.take(scores, safe_indices, axis=0) - + out = relax.Tuple([selected_indices, selected_scores, num_valid]) return out diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index c0de33748efd..02282f3d41c9 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -1027,9 +1027,7 @@ def _verify_nms_v5(mod, tf_func, boxes_np, scores_np): if "CI_ENV_NIGHTLY" not in os.environ: return - tf_indices, tf_scores, tf_valid = tf_func( - tf.constant(boxes_np), tf.constant(scores_np) - ) + tf_indices, tf_scores, tf_valid = tf_func(tf.constant(boxes_np), tf.constant(scores_np)) n_valid = int(tf_valid.numpy()) tgt = tvm.target.Target("llvm") @@ -1100,51 +1098,75 @@ def _make_valid_boxes(rng, n): _NMS_V5_CASES = [ pytest.param( - 6, 3, 0.5, 0.0, - np.array([ - [0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0], - [0.0, 0.1, 1.0, 1.1], - [0.0, 0.0, 1.0, 0.9], - [0.5, 0.5, 1.5, 1.5], - [0.0, 0.0, 0.3, 0.3], - ], dtype=np.float32), + 6, + 3, + 0.5, + 0.0, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, 0.0, 1.0, 0.9], + [0.5, 0.5, 1.5, 1.5], + [0.0, 0.0, 0.3, 0.3], + ], + dtype=np.float32, + ), np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), id="basic", ), pytest.param( - 8, 4, 0.5, 0.4, + 8, + 4, + 0.5, + 0.4, _make_valid_boxes(np.random.default_rng(42), 8), np.random.default_rng(42).random(8, dtype=np.float32), id="score_threshold", ), pytest.param( - 5, 3, 0.5, 0.99, + 5, + 3, + 0.5, + 0.99, _make_valid_boxes(np.random.default_rng(0), 5), np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32), id="all_suppressed", ), pytest.param( - 6, 6, 0.1, 0.0, - np.array([ - [0.0, 0.0, 0.4, 0.4], - [0.5, 0.5, 0.9, 0.9], - [0.1, 0.1, 0.5, 0.5], - [0.6, 0.6, 1.0, 1.0], - [0.0, 0.5, 0.4, 0.9], - [0.5, 0.0, 0.9, 0.4], - ], dtype=np.float32), + 6, + 6, + 0.1, + 0.0, + np.array( + [ + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.9, 0.9], + [0.1, 0.1, 0.5, 0.5], + [0.6, 0.6, 1.0, 1.0], + [0.0, 0.5, 0.4, 0.9], + [0.5, 0.0, 0.9, 0.4], + ], + dtype=np.float32, + ), np.array([0.9, 0.85, 0.7, 0.65, 0.6, 0.55], dtype=np.float32), id="iou_threshold", ), pytest.param( - 4, 10, 0.5, 0.0, - np.array([ - [0.0, 0.0, 0.3, 0.3], - [0.5, 0.5, 0.8, 0.8], - [0.1, 0.1, 0.4, 0.4], - [0.6, 0.6, 0.9, 0.9], - ], dtype=np.float32), + 4, + 10, + 0.5, + 0.0, + np.array( + [ + [0.0, 0.0, 0.3, 0.3], + [0.5, 0.5, 0.8, 0.8], + [0.1, 0.1, 0.4, 0.4], + [0.6, 0.6, 0.9, 0.9], + ], + dtype=np.float32, + ), np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32), id="max_output_size_larger_than_boxes", ), @@ -1185,7 +1207,9 @@ def test_nms_v5_ir(): assert f"R.Tensor(({max_output_size},)" in ir -def _make_resize_expected(input_shape, output_size, method, coordinate_transformation_mode, rounding_method): +def _make_resize_expected( + input_shape, output_size, method, coordinate_transformation_mode, rounding_method +): """Build an Expected IRModule programmatically to avoid TVMScript variable scope limitations.""" bb = relax.BlockBuilder() x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) @@ -1215,13 +1239,48 @@ def _make_resize_expected(input_shape, output_size, method, coordinate_transform @pytest.mark.parametrize( "input_shape, output_size, tf_op, coordinate_transformation_mode", [ - ((1, 4, 4, 1), [8, 8], lambda x: tf.image.resize(x, [8, 8], method="bilinear"), "half_pixel"), - ((1, 8, 8, 3), [4, 4], lambda x: tf.image.resize(x, [4, 4], method="bilinear"), "half_pixel"), - ((1, 4, 4, 1), [7, 7], lambda x: tf.compat.v1.image.resize_bilinear(x, [7, 7], align_corners=True), "align_corners"), - ((1, 4, 4, 2), [8, 8], lambda x: tf.compat.v1.image.resize_bilinear(x, [8, 8], half_pixel_centers=True), "half_pixel"), - ((2, 6, 6, 16), [12, 12], lambda x: tf.image.resize(x, [12, 12], method="bilinear"), "half_pixel"), - ((1, 5, 5, 3), [5, 5], lambda x: tf.image.resize(x, [5, 5], method="bilinear"), "half_pixel"), - ((1, 4, 8, 1), [8, 16], lambda x: tf.image.resize(x, [8, 16], method="bilinear"), "half_pixel"), + ( + (1, 4, 4, 1), + [8, 8], + lambda x: tf.image.resize(x, [8, 8], method="bilinear"), + "half_pixel", + ), + ( + (1, 8, 8, 3), + [4, 4], + lambda x: tf.image.resize(x, [4, 4], method="bilinear"), + "half_pixel", + ), + ( + (1, 4, 4, 1), + [7, 7], + lambda x: tf.compat.v1.image.resize_bilinear(x, [7, 7], align_corners=True), + "align_corners", + ), + ( + (1, 4, 4, 2), + [8, 8], + lambda x: tf.compat.v1.image.resize_bilinear(x, [8, 8], half_pixel_centers=True), + "half_pixel", + ), + ( + (2, 6, 6, 16), + [12, 12], + lambda x: tf.image.resize(x, [12, 12], method="bilinear"), + "half_pixel", + ), + ( + (1, 5, 5, 3), + [5, 5], + lambda x: tf.image.resize(x, [5, 5], method="bilinear"), + "half_pixel", + ), + ( + (1, 4, 8, 1), + [8, 16], + lambda x: tf.image.resize(x, [8, 16], method="bilinear"), + "half_pixel", + ), ], ) def test_resize_bilinear(input_shape, output_size, tf_op, coordinate_transformation_mode): @@ -1230,28 +1289,74 @@ class ResizeBilinear(tf.Module): def func(self, x): return tf_op(x) - expected = _make_resize_expected(input_shape, output_size, "linear", coordinate_transformation_mode, "") + expected = _make_resize_expected( + input_shape, output_size, "linear", coordinate_transformation_mode, "" + ) verify(ResizeBilinear, expected) @pytest.mark.parametrize( "input_shape, output_size, tf_op, coordinate_transformation_mode, rounding_method", [ - ((1, 2, 2, 1), [4, 4], lambda x: tf.image.resize(x, [4, 4], method="nearest"), "half_pixel", "round_prefer_ceil"), - ((1, 8, 8, 3), [4, 4], lambda x: tf.image.resize(x, [4, 4], method="nearest"), "half_pixel", "round_prefer_ceil"), - ((1, 4, 4, 1), [7, 7], lambda x: tf.compat.v1.image.resize_nearest_neighbor(x, [7, 7], align_corners=True), "align_corners", ""), - ((4, 3, 3, 8), [6, 6], lambda x: tf.image.resize(x, [6, 6], method="nearest"), "half_pixel", "round_prefer_ceil"), - ((1, 4, 8, 1), [8, 16], lambda x: tf.image.resize(x, [8, 16], method="nearest"), "half_pixel", "round_prefer_ceil"), - ((1, 3, 3, 2), [3, 3], lambda x: tf.image.resize(x, [3, 3], method="nearest"), "half_pixel", "round_prefer_ceil"), + ( + (1, 2, 2, 1), + [4, 4], + lambda x: tf.image.resize(x, [4, 4], method="nearest"), + "half_pixel", + "round_prefer_ceil", + ), + ( + (1, 8, 8, 3), + [4, 4], + lambda x: tf.image.resize(x, [4, 4], method="nearest"), + "half_pixel", + "round_prefer_ceil", + ), + ( + (1, 4, 4, 1), + [7, 7], + lambda x: tf.compat.v1.image.resize_nearest_neighbor(x, [7, 7], align_corners=True), + "align_corners", + "", + ), + ( + (4, 3, 3, 8), + [6, 6], + lambda x: tf.image.resize(x, [6, 6], method="nearest"), + "half_pixel", + "round_prefer_ceil", + ), + ( + (1, 4, 8, 1), + [8, 16], + lambda x: tf.image.resize(x, [8, 16], method="nearest"), + "half_pixel", + "round_prefer_ceil", + ), + ( + (1, 3, 3, 2), + [3, 3], + lambda x: tf.image.resize(x, [3, 3], method="nearest"), + "half_pixel", + "round_prefer_ceil", + ), ], ) -def test_resize_nearest_neighbor(input_shape, output_size, tf_op, coordinate_transformation_mode, rounding_method): +def test_resize_nearest_neighbor( + input_shape, output_size, tf_op, coordinate_transformation_mode, rounding_method +): class ResizeNearest(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) def func(self, x): return tf_op(x) - expected = _make_resize_expected(input_shape, output_size, "nearest_neighbor", coordinate_transformation_mode, rounding_method) + expected = _make_resize_expected( + input_shape, + output_size, + "nearest_neighbor", + coordinate_transformation_mode, + rounding_method, + ) verify(ResizeNearest, expected) @@ -1378,9 +1483,9 @@ class Expected: def main(x: R.Tensor((5,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - lv: R.Tuple( - R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="int32") - ) = R.topk(x, k=3, axis=-1, ret_type="both", largest=True, dtype="int32") + lv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="int32")) = ( + R.topk(x, k=3, axis=-1, ret_type="both", largest=True, dtype="int32") + ) gv: R.Tensor((3,), dtype="float32") = lv[0] R.output(gv) return gv @@ -1413,5 +1518,88 @@ def main(x: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 4), dtype="float32"): verify(OneHot, Expected) +def test_select(): + class Select(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2, 3), dtype=tf.bool), + tf.TensorSpec(shape=(2, 3), dtype=tf.float32), + tf.TensorSpec(shape=(2, 3), dtype=tf.float32), + ] + ) + def func(self, cond, x, y): + return tf.where(cond, x, y) + + @I.ir_module + class Expected: + @R.function + def main( + cond: R.Tensor((2, 3), dtype="bool"), + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="float32") = R.where(cond, x, y) + R.output(gv) + return gv + + verify(Select, Expected) + + +def test_depth_to_space(): + class DepthToSpace(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 2, 4, 8), dtype=tf.float32)]) + def func(self, x): + return tf.nn.depth_to_space(x, block_size=2) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 2, 4, 8), dtype="float32"), + ) -> R.Tensor((1, 4, 8, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((1, 2, 4, 2, 2, 2), dtype="float32") = R.reshape( + x, R.shape([1, 2, 4, 2, 2, 2]) + ) + lv1: R.Tensor((1, 2, 2, 4, 2, 2), dtype="float32") = R.permute_dims( + lv, axes=[0, 1, 3, 2, 4, 5] + ) + gv: R.Tensor((1, 4, 8, 2), dtype="float32") = R.reshape(lv1, R.shape([1, 4, 8, 2])) + R.output(gv) + return gv + + verify(DepthToSpace, Expected) + + +def test_space_to_depth(): + class SpaceToDepth(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 4, 4, 2), dtype=tf.float32)]) + def func(self, x): + return tf.nn.space_to_depth(x, block_size=2) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 4, 4, 2), dtype="float32"), + ) -> R.Tensor((1, 2, 2, 8), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((1, 2, 2, 2, 2, 2), dtype="float32") = R.reshape( + x, R.shape([1, 2, 2, 2, 2, 2]) + ) + lv1: R.Tensor((1, 2, 2, 2, 2, 2), dtype="float32") = R.permute_dims( + lv, axes=[0, 1, 3, 2, 4, 5] + ) + gv: R.Tensor((1, 2, 2, 8), dtype="float32") = R.reshape(lv1, R.shape([1, 2, 2, 8])) + R.output(gv) + return gv + + verify(SpaceToDepth, Expected) + + if __name__ == "__main__": pytest.main(["-s", __file__])