1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements logic for lowering the MHLO dialect to the upstream
// MLIR "standard" dialects (arith, math and func).
17
18 #include <utility>
19
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
24 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/Math/IR/Math.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31 namespace mlir {
32 namespace {
33 #include "generated_legalize_to_standard.inc"
34 } // end anonymous namespace
35 namespace mhlo {
36 namespace {
37
38 class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
39 public:
40 using OpRewritePattern::OpRewritePattern;
41
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const42 LogicalResult matchAndRewrite(mhlo::CompareOp op,
43 PatternRewriter &rewriter) const override {
44 auto lhs = op.lhs();
45 auto rhs = op.rhs();
46 auto lhsType = lhs.getType().cast<TensorType>();
47 auto rhsType = rhs.getType().cast<TensorType>();
48
49 // Broadcasting not supported by this rewrite.
50 if (lhsType.getShape() != rhsType.getShape()) return failure();
51
52 if (!lhsType.getElementType().isSignlessInteger() ||
53 !rhsType.getElementType().isSignlessInteger())
54 return failure();
55
56 Optional<arith::CmpIPredicate> comparePredicate = llvm::None;
57 switch (op.comparison_direction()) {
58 case ComparisonDirection::EQ:
59 comparePredicate = arith::CmpIPredicate::eq;
60 break;
61 case ComparisonDirection::NE:
62 comparePredicate = arith::CmpIPredicate::ne;
63 break;
64 case ComparisonDirection::LT:
65 comparePredicate = arith::CmpIPredicate::slt;
66 break;
67 case ComparisonDirection::LE:
68 comparePredicate = arith::CmpIPredicate::sle;
69 break;
70 case ComparisonDirection::GT:
71 comparePredicate = arith::CmpIPredicate::sgt;
72 break;
73 case ComparisonDirection::GE:
74 comparePredicate = arith::CmpIPredicate::sge;
75 break;
76 }
77
78 if (!comparePredicate.has_value()) return failure();
79
80 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, comparePredicate.value(),
81 lhs, rhs);
82 return success();
83 }
84 };
85
86 class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
87 public:
88 using OpRewritePattern::OpRewritePattern;
89
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const90 LogicalResult matchAndRewrite(mhlo::CompareOp op,
91 PatternRewriter &rewriter) const override {
92 auto lhs = op.lhs();
93 auto rhs = op.rhs();
94 auto lhsType = lhs.getType().cast<TensorType>();
95 auto rhsType = rhs.getType().cast<TensorType>();
96
97 // Broadcasting not supported by this rewrite.
98 if (lhsType.getShape() != rhsType.getShape()) return failure();
99
100 if (!lhsType.getElementType().isa<FloatType>() ||
101 !rhsType.getElementType().isa<FloatType>())
102 return failure();
103
104 Optional<arith::CmpFPredicate> comparePredicate = llvm::None;
105 switch (op.comparison_direction()) {
106 case ComparisonDirection::EQ:
107 comparePredicate = arith::CmpFPredicate::OEQ;
108 break;
109 case ComparisonDirection::NE:
110 comparePredicate = arith::CmpFPredicate::UNE;
111 break;
112 case ComparisonDirection::LT:
113 comparePredicate = arith::CmpFPredicate::OLT;
114 break;
115 case ComparisonDirection::LE:
116 comparePredicate = arith::CmpFPredicate::OLE;
117 break;
118 case ComparisonDirection::GT:
119 comparePredicate = arith::CmpFPredicate::OGT;
120 break;
121 case ComparisonDirection::GE:
122 comparePredicate = arith::CmpFPredicate::OGE;
123 break;
124 }
125
126 if (!comparePredicate.has_value()) return failure();
127
128 rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, comparePredicate.value(),
129 lhs, rhs);
130 return success();
131 }
132 };
133
134 // Replace IotaOp with an integer constant. A ConvertOp is added to
135 // convert the integer constant to iota result type. For complex types, the real
136 // part is replaced with the generated constant and the imaginary part is
137 // replaced with zero tensor.
138 class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
139 public:
140 using OpRewritePattern::OpRewritePattern;
141
matchAndRewrite(mhlo::IotaOp op,PatternRewriter & rewriter) const142 LogicalResult matchAndRewrite(mhlo::IotaOp op,
143 PatternRewriter &rewriter) const override {
144 auto outputType = op.getType().cast<ShapedType>();
145 auto outputSize = outputType.getNumElements();
146 auto dimension = op.iota_dimension();
147 auto maxDimSize = outputType.getDimSize(dimension);
148
149 auto elementType = outputType.getElementType();
150 int bitwidth;
151
152 auto complexTy = elementType.dyn_cast<ComplexType>();
153 Type intOrFloatTy = elementType;
154 if (complexTy) intOrFloatTy = complexTy.getElementType();
155
156 bitwidth = intOrFloatTy.getIntOrFloatBitWidth();
157 llvm::SmallVector<APInt, 10> values;
158 values.reserve(outputSize);
159
160 int64_t increaseStride = outputSize;
161 for (uint64_t i = 0; i <= dimension; i++) {
162 increaseStride /= outputType.getDimSize(i);
163 }
164
165 int64_t currentValue = 0;
166 for (int i = 0; i < outputSize; i++) {
167 int64_t value = (currentValue / increaseStride) % maxDimSize;
168 values.push_back(APInt(bitwidth, value));
169 ++currentValue;
170 }
171
172 auto intShapeType = RankedTensorType::get(
173 outputType.getShape(),
174 IntegerType::get(rewriter.getContext(), bitwidth));
175 auto loc = op.getLoc();
176 auto integerConst = rewriter.create<mlir::arith::ConstantOp>(
177 loc, DenseIntElementsAttr::get(intShapeType, values));
178
179 auto intOrFloatShapeTy =
180 RankedTensorType::get(outputType.getShape(), intOrFloatTy);
181
182 auto iotaConst =
183 rewriter.create<ConvertOp>(loc, intOrFloatShapeTy, integerConst);
184
185 // For int/float types we are done, replace op and return.
186 if (!complexTy) {
187 rewriter.replaceOp(op, iotaConst.getResult());
188 return success();
189 }
190
191 // For complex types, generate a constant tensor of zeroes for the imaginary
192 // part and use iota_const for real part.
193 auto zeroes = rewriter.create<mlir::arith::ConstantOp>(
194 loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0)));
195 auto imagZeroes =
196 rewriter.create<ConvertOp>(loc, intOrFloatShapeTy, zeroes);
197 rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iotaConst, imagZeroes);
198 return success();
199 }
200 };
201
202 } // end anonymous namespace
203
204 namespace {
205 struct LegalizeToStandardPass
206 : public LegalizeToStandardPassBase<LegalizeToStandardPass> {
getDependentDialectsmlir::mhlo::__anonc4a93fc90311::LegalizeToStandardPass207 void getDependentDialects(DialectRegistry ®istry) const override {
208 registry.insert<arith::ArithmeticDialect, math::MathDialect,
209 func::FuncDialect>();
210 }
211
212 /// Perform the lowering to Standard dialect.
213 void runOnOperation() override;
214 };
215 } // end anonymous namespace
216
217 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeToStdPass()218 createLegalizeToStdPass() {
219 return std::make_unique<LegalizeToStandardPass>();
220 }
221
populateMhloToStdPatterns(RewritePatternSet * patterns,mlir::MLIRContext * ctx)222 void populateMhloToStdPatterns(RewritePatternSet *patterns,
223 mlir::MLIRContext *ctx) {
224 mlir::populateWithGenerated(*patterns);
225 patterns->add<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
226 }
227
228 /// Perform the lowering to standard dialect.
runOnOperation()229 void LegalizeToStandardPass::runOnOperation() {
230 RewritePatternSet patterns(&getContext());
231 mlir::mhlo::populateMhloToStdPatterns(&patterns, &getContext());
232 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
233 return signalPassFailure();
234 }
235
236 } // end namespace mhlo
237 } // end namespace mlir
238