diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 275e162b818b..a3730e8a0c0f 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -707,6 +707,73 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in verify(TfInput, Expected) +def test_fully_connected(): + class FullyConnected(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8), dtype=tf.float32)]) + def func(self, x): + weight = tf.constant(np.arange(24, dtype=np.float32).reshape((3, 8))) + bias = tf.constant(np.array([0.5, 1.0, -1.0], dtype=np.float32)) + out = tf.matmul(x, weight, transpose_b=True) + return tf.nn.bias_add(out, bias) + + verify(FullyConnected) + + +def test_depthwise_conv2d(): + class DepthwiseConv2D(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32), + tf.TensorSpec(shape=(3, 3, 2, 1), dtype=tf.float32), + ] + ) + def func(self, data, kernel): + return tf.nn.depthwise_conv2d( + input=data, + filter=kernel, + strides=[1, 1, 1, 1], + padding="SAME", + ) + + verify(DepthwiseConv2D) + + +def test_transpose_conv(): + class TransposeConv(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32), + tf.TensorSpec(shape=(3, 3, 3, 2), dtype=tf.float32), + ] + ) + def func(self, data, kernel): + output_shape = tf.constant([1, 8, 8, 3], dtype=tf.int32) + return tf.nn.conv2d_transpose( + input=data, + filters=kernel, + output_shape=output_shape, + strides=[1, 1, 1, 1], + padding="SAME", + ) + + verify(TransposeConv) + + +def test_l2_pool2d(): + class L2Pool2D(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)]) + def func(self, data): + return tf.nn.pool( + input=data, + window_shape=[2, 2], + pooling_type="AVG", + strides=[1, 1], + padding="SAME", + ) + + verify(L2Pool2D) + + def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function(