Skip to content

Commit 96805b2

Browse files
committed
mlir: enable RemoveDeadValues
1 parent 621054a commit 96805b2

5 files changed

Lines changed: 191 additions & 71 deletions

File tree

src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -687,12 +687,18 @@ namespace py {
687687
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertWithPass)
688688
};
689689

690-
// Pattern: rewrite a zero-operand func.return inside a function
691-
// whose return type expects at least one result by inserting a
692-
// None constant. Reaches for emitpybytecode::ConstantOp because
693-
// the pass runs after PythonToPythonBytecodePass has already
694-
// lowered py.constant; using py.constant here would re-introduce
695-
// an illegal source-dialect op into the lowered IR.
690+
// Pattern: rewrite a zero-operand func.return by inserting a None
691+
// constant operand and, if RemoveDeadValues also rewrote the
692+
// parent FuncOp's signature to return nothing, restoring its
693+
// declared result type to PyObjectType. The bytecode emitter
694+
// assumes every function returns a value (Python's "every
695+
// function returns at minimum None") regardless of whether MLIR
696+
// sees the result as used.
697+
//
698+
// Reaches for emitpybytecode::ConstantOp because the pass runs
699+
// after PythonToPythonBytecodePass has already lowered
700+
// py.constant; using py.constant here would re-introduce an
701+
// illegal source-dialect op into the lowered IR.
696702
struct MaterialiseReturnNonePattern : public mlir::OpRewritePattern<mlir::func::ReturnOp>
697703
{
698704
using mlir::OpRewritePattern<mlir::func::ReturnOp>::OpRewritePattern;
@@ -703,12 +709,23 @@ namespace py {
703709
if (op.getNumOperands() != 0) { return mlir::failure(); }
704710
auto parent = op->getParentOfType<mlir::func::FuncOp>();
705711
if (!parent) { return mlir::failure(); }
706-
if (parent.getFunctionType().getNumResults() == 0) { return mlir::failure(); }
712+
auto pyobject_ty = mlir::py::PyObjectType::get(rewriter.getContext());
707713
rewriter.setInsertionPoint(op);
708-
auto none = rewriter.create<mlir::emitpybytecode::ConstantOp>(op.getLoc(),
709-
mlir::py::PyObjectType::get(rewriter.getContext()),
710-
rewriter.getUnitAttr());
714+
auto none = rewriter.create<mlir::emitpybytecode::ConstantOp>(
715+
op.getLoc(), pyobject_ty, rewriter.getUnitAttr());
711716
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, mlir::ValueRange{ none });
717+
718+
// Restore the function signature if RemoveDeadValues stripped
719+
// the result type. Plain assignment to the function-type
720+
// attribute is fine here because the parent op's properties
721+
// aren't tracked by the pattern rewriter's mutation tracking
722+
// (we already produced a successful match-and-rewrite via
723+
// replaceOpWithNewOp above).
724+
if (parent.getFunctionType().getNumResults() == 0) {
725+
auto fn_ty = parent.getFunctionType();
726+
parent.setFunctionType(rewriter.getFunctionType(
727+
fn_ty.getInputs(), mlir::TypeRange{ pyobject_ty }));
728+
}
712729
return mlir::success();
713730
}
714731
};

0 commit comments

Comments
 (0)