@@ -687,12 +687,18 @@ namespace py {
687687 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (ConvertWithPass)
688688 };
689689
690- // Pattern: rewrite a zero-operand func.return inside a function
691- // whose return type expects at least one result by inserting a
692- // None constant. Reaches for emitpybytecode::ConstantOp because
693- // the pass runs after PythonToPythonBytecodePass has already
694- // lowered py.constant; using py.constant here would re-introduce
695- // an illegal source-dialect op into the lowered IR.
690+ // Pattern: rewrite a zero-operand func.return by inserting a None
691+ // constant operand and, if RemoveDeadValues also rewrote the
692+ // parent FuncOp's signature to return nothing, restoring its
693+ // declared result type to PyObjectType. The bytecode emitter
694+ // assumes every function returns a value (Python's "every
695+ // function returns at minimum None") regardless of whether MLIR
696+ // sees the result as used.
697+ //
698+ // Reaches for emitpybytecode::ConstantOp because the pass runs
699+ // after PythonToPythonBytecodePass has already lowered
700+ // py.constant; using py.constant here would re-introduce an
701+ // illegal source-dialect op into the lowered IR.
696702 struct MaterialiseReturnNonePattern : public mlir ::OpRewritePattern<mlir::func::ReturnOp>
697703 {
698704 using mlir::OpRewritePattern<mlir::func::ReturnOp>::OpRewritePattern;
@@ -703,12 +709,23 @@ namespace py {
703709 if (op.getNumOperands () != 0 ) { return mlir::failure (); }
704710 auto parent = op->getParentOfType <mlir::func::FuncOp>();
705711 if (!parent) { return mlir::failure (); }
706- if (parent. getFunctionType (). getNumResults () == 0 ) { return mlir::failure (); }
712+ auto pyobject_ty = mlir::py::PyObjectType::get (rewriter. getContext ());
707713 rewriter.setInsertionPoint (op);
708- auto none = rewriter.create <mlir::emitpybytecode::ConstantOp>(op.getLoc (),
709- mlir::py::PyObjectType::get (rewriter.getContext ()),
710- rewriter.getUnitAttr ());
714+ auto none = rewriter.create <mlir::emitpybytecode::ConstantOp>(
715+ op.getLoc (), pyobject_ty, rewriter.getUnitAttr ());
711716 rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(op, mlir::ValueRange{ none });
717+
718+ // Restore the function signature if RemoveDeadValues stripped
719+ // the result type. Plain assignment to the function-type
720+ // attribute is fine here because the parent op's properties
721+ // aren't tracked by the pattern rewriter's mutation tracking
722+ // (we already produced a successful match-and-rewrite via
723+ // replaceOpWithNewOp above).
724+ if (parent.getFunctionType ().getNumResults () == 0 ) {
725+ auto fn_ty = parent.getFunctionType ();
726+ parent.setFunctionType (rewriter.getFunctionType (
727+ fn_ty.getInputs (), mlir::TypeRange{ pyobject_ty }));
728+ }
712729 return mlir::success ();
713730 }
714731 };
0 commit comments