diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 91b6a3a171af..4c0edcb0f286 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -391,6 +391,17 @@ def _log_softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + def _logical_and(self, node: fx.Node) -> relax.Var: + lhs = self.env[node.args[0]] + rhs = self.env[node.args[1]] + # torch.logical_and accepts any dtype (treating nonzero as True) and returns bool, but + # relax.op.logical_and requires boolean inputs, so cast non-bool inputs to bool first. + if lhs.struct_info.dtype != "bool": + lhs = self.block_builder.emit(relax.op.astype(lhs, "bool")) + if rhs.struct_info.dtype != "bool": + rhs = self.block_builder.emit(relax.op.astype(rhs, "bool")) + return self.block_builder.emit(relax.op.logical_and(lhs, rhs)) + def _logical_not(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] # torch.logical_not accepts any dtype (treating nonzero as True) and returns bool, but diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 976c9d45b6f0..7924a2305c95 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1552,7 +1552,7 @@ def create_convert_map( "log10.default": self._log10, "log1p.default": self._log1p, "logical_not.default": self._logical_not, - "logical_and.default": self._binary_op(relax.op.logical_and, operator.and_), + "logical_and.default": self._logical_and, "log_softmax.int": self._log_softmax, "_log_softmax.default": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 867407193abf..4af86068d7e7 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -875,6 +875,7 @@ def create_convert_map( "log2": self._log2, "log10": self._log10, "log1p": self._log1p, + "logical_and": self._logical_and, "logical_not": self._logical_not, "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 86471d892473..fa2d793f29cf 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1062,6 +1062,34 @@ def main( verify_model(LogAddExp(), example_args, {}, expected) +def test_logical_and(): + class LogicalAnd(Module): + def forward(self, lhs, rhs): + return torch.logical_and(lhs, rhs) + + @tvm.script.ir_module + class expected: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool") + lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_and(lv, lv1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv2,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 10, 10, dtype=torch.float32), + torch.randn(1, 3, 10, 10, dtype=torch.float32), + ) + verify_model(LogicalAnd(), example_args, {}, expected) + + def test_logical_not(): class LogicalNot(Module): def forward(self, input): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index abfb18cf412a..94cdf437739e 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3527,6 +3527,31 @@ def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor( verify_model(Trunc(), input_info, {}, expected_trunc) +def test_logical_and(): + input_info = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] + + class LogicalAnd(Module): + def forward(self, lhs, rhs): + return torch.logical_and(lhs, rhs) + + @tvm.script.ir_module + class expected: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool") + lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_and(lv, lv1) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv2 + R.output(gv) + return gv + + verify_model(LogicalAnd(), input_info, {}, expected) + + def test_pow_integer(): input_info = [([4], "int64")]