-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathArithPatterns.cpp
More file actions
147 lines (124 loc) · 5.21 KB
/
ArithPatterns.cpp
File metadata and controls
147 lines (124 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp"
#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp"
#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp"
#include "Dialect/Python/IR/PythonOps.hpp"
#include "executable/bytecode/instructions/BinaryOperation.hpp"
#include "executable/bytecode/instructions/Unary.hpp"
#include "utilities.hpp"
#include "mlir/IR/PatternMatch.h"
namespace mlir::py {
namespace {
// Map the ArithOpKind carried by py.binary / py.inplace_op onto the
// bytecode-level BinaryOperation::Operation value.
//
// The dialect enum is part of the IR contract, whereas the bytecode enum is
// the wire format the VM consumes. The two are intentionally kept decoupled,
// so every correspondence is listed explicitly below instead of relying on
// the enumerators sharing numeric values.
BinaryOperation::Operation py_kind_to_binary_op(mlir::py::ArithOpKind kind)
{
	using ArithKind = mlir::py::ArithOpKind;
	using BinOp = BinaryOperation::Operation;

	switch (kind) {
	case ArithKind::add: return BinOp::PLUS;
	case ArithKind::sub: return BinOp::MINUS;
	case ArithKind::mod: return BinOp::MODULO;
	case ArithKind::mul: return BinOp::MULTIPLY;
	case ArithKind::exp: return BinOp::EXP;
	case ArithKind::div: return BinOp::SLASH;
	case ArithKind::fldiv: return BinOp::FLOORDIV;
	case ArithKind::mmul: return BinOp::MATMUL;
	case ArithKind::lshift: return BinOp::LEFTSHIFT;
	case ArithKind::rshift: return BinOp::RIGHTSHIFT;
	case ArithKind::and_: return BinOp::AND;
	case ArithKind::or_: return BinOp::OR;
	case ArithKind::xor_: return BinOp::XOR;
	}
	// Only reachable if `kind` holds a value outside the enumerators handled above.
	ASSERT_NOT_REACHED();
}
// Rewrites py.inplace_op into emitpybytecode's InplaceOp. The arithmetic kind
// is translated by py_kind_to_binary_op and attached as an unsigned 8-bit
// integer attribute. The destination/source operands and the result type are
// forwarded unchanged.
struct InplaceOpLowering : public mlir::OpRewritePattern<InplaceOp>
{
	using OpRewritePattern<InplaceOp>::OpRewritePattern;

	mlir::LogicalResult matchAndRewrite(InplaceOp op,
		mlir::PatternRewriter &rewriter) const final
	{
		const auto bytecode_op = py_kind_to_binary_op(op.getKind());
		const auto u8_type = rewriter.getIntegerType(8, false);
		const auto kind_attr = mlir::IntegerAttr::get(u8_type, static_cast<uint8_t>(bytecode_op));

		rewriter.replaceOpWithNewOp<mlir::emitpybytecode::InplaceOp>(
			op, op.getResult().getType(), op.getDst(), op.getSrc(), kind_attr);
		return success();
	}
};
// Rewrites py.binary into emitpybytecode's BinaryOp. The arithmetic kind is
// translated by py_kind_to_binary_op and attached as an unsigned 8-bit integer
// attribute. The lhs/rhs operands and the output type are forwarded unchanged.
struct BinaryOpLowering : public mlir::OpRewritePattern<mlir::py::BinaryOp>
{
	using OpRewritePattern<mlir::py::BinaryOp>::OpRewritePattern;

	mlir::LogicalResult matchAndRewrite(mlir::py::BinaryOp op,
		mlir::PatternRewriter &rewriter) const final
	{
		const auto bytecode_op = py_kind_to_binary_op(op.getKind());
		const auto u8_type = rewriter.getIntegerType(8, false);
		const auto kind_attr = mlir::IntegerAttr::get(u8_type, static_cast<uint8_t>(bytecode_op));

		rewriter.replaceOpWithNewOp<mlir::emitpybytecode::BinaryOp>(
			op, op.getOutput().getType(), op.getLhs(), op.getRhs(), kind_attr);
		return success();
	}
};
// Generic one-to-one lowering from a py unary op (`From`) to emitpybytecode's
// UnaryOp. The Unary::Operation selector is fixed per instantiation through the
// `Kind` template argument and handed to the builder as a uint8_t. The input
// operand and the output type are forwarded unchanged.
template<typename From, Unary::Operation Kind>
struct UnaryOpLowering : public mlir::OpRewritePattern<From>
{
	using mlir::OpRewritePattern<From>::OpRewritePattern;

	mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final
	{
		constexpr auto opcode = static_cast<uint8_t>(Kind);
		auto result_type = op.getOutput().getType();

		rewriter.template replaceOpWithNewOp<mlir::emitpybytecode::UnaryOp>(
			op, result_type, op.getInput(), opcode);
		return mlir::success();
	}
};
// Concrete unary lowerings: each py unary op is paired with its Unary::Operation code.
// (Unqualified op names resolve to mlir::py, the enclosing namespace.)
using PositiveOpLowering = UnaryOpLowering<PositiveOp, Unary::Operation::POSITIVE>;
using NegativeOpLowering = UnaryOpLowering<NegativeOp, Unary::Operation::NEGATIVE>;
using InvertOpLowering = UnaryOpLowering<InvertOp, Unary::Operation::INVERT>;
using NotOpLowering = UnaryOpLowering<NotOp, Unary::Operation::NOT>;
// Rewrites py.compare into emitpybytecode's Compare. The predicate is attached
// as an unsigned 8-bit integer attribute. The lhs/rhs operands and the output
// type are forwarded unchanged.
//
// NOTE(review): unlike the arithmetic kinds (translated case by case in
// py_kind_to_binary_op), the predicate enum is cast straight to uint8_t. That
// assumes the dialect's predicate enumerators share numeric values with the
// bytecode-level comparison enum the VM expects. Confirm, or add an explicit
// mapping.
struct CompareOpLowering : public mlir::OpRewritePattern<mlir::py::CompareOp>
{
	using OpRewritePattern<mlir::py::CompareOp>::OpRewritePattern;

	mlir::LogicalResult matchAndRewrite(mlir::py::CompareOp op,
		mlir::PatternRewriter &rewriter) const final
	{
		const auto u8_type = rewriter.getIntegerType(8, false);
		const auto predicate_attr =
			mlir::IntegerAttr::get(u8_type, static_cast<uint8_t>(op.getPredicate()));

		rewriter.replaceOpWithNewOp<mlir::emitpybytecode::Compare>(
			op, op.getOutput().getType(), op.getLhs(), op.getRhs(), predicate_attr);
		return success();
	}
};
// Rewrites py.cast_to_bool into emitpybytecode's CastToBool, passing the single
// operand through.
//
// NOTE(review): the replacement's result type is taken from the *operand*
// (op.getValue()), whereas the other lowerings in this file reuse the replaced
// op's result type. That is type-preserving only if py.cast_to_bool's result
// type equals its operand type. Confirm this is intentional.
struct CastToBoolOpLowering : public mlir::OpRewritePattern<mlir::py::CastToBoolOp>
{
	using OpRewritePattern<mlir::py::CastToBoolOp>::OpRewritePattern;

	mlir::LogicalResult matchAndRewrite(mlir::py::CastToBoolOp op,
		mlir::PatternRewriter &rewriter) const final
	{
		auto operand = op.getValue();
		rewriter.replaceOpWithNewOp<mlir::emitpybytecode::CastToBool>(
			op, operand.getType(), operand);
		return success();
	}
};
}// namespace
// Adds the binary, in-place, comparison, bool-cast and unary lowerings defined
// in this file to `patterns`, all bound to the pattern set's own MLIRContext.
void populateArithPatterns(mlir::RewritePatternSet &patterns)
{
	mlir::MLIRContext *context = patterns.getContext();

	// Two-operand and conversion lowerings.
	patterns.add<BinaryOpLowering, InplaceOpLowering, CompareOpLowering, CastToBoolOpLowering>(
		context);

	// Single-operand lowerings, all instantiations of UnaryOpLowering.
	patterns.add<PositiveOpLowering, NegativeOpLowering, InvertOpLowering, NotOpLowering>(context);
}
}// namespace mlir::py