1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements logic for lowering the MHLO dialect to core MLIR
// dialects (Arithmetic, Math, Func) that replaced the former Standard dialect.
17 
18 #include <utility>
19 
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
24 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/Math/IR/Math.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 
31 namespace mlir {
32 namespace {
33 #include "generated_legalize_to_standard.inc"
34 }  // end anonymous namespace
35 namespace mhlo {
36 namespace {
37 
38 class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
39  public:
40   using OpRewritePattern::OpRewritePattern;
41 
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const42   LogicalResult matchAndRewrite(mhlo::CompareOp op,
43                                 PatternRewriter &rewriter) const override {
44     auto lhs = op.lhs();
45     auto rhs = op.rhs();
46     auto lhsType = lhs.getType().cast<TensorType>();
47     auto rhsType = rhs.getType().cast<TensorType>();
48 
49     // Broadcasting not supported by this rewrite.
50     if (lhsType.getShape() != rhsType.getShape()) return failure();
51 
52     if (!lhsType.getElementType().isSignlessInteger() ||
53         !rhsType.getElementType().isSignlessInteger())
54       return failure();
55 
56     Optional<arith::CmpIPredicate> comparePredicate = llvm::None;
57     switch (op.comparison_direction()) {
58       case ComparisonDirection::EQ:
59         comparePredicate = arith::CmpIPredicate::eq;
60         break;
61       case ComparisonDirection::NE:
62         comparePredicate = arith::CmpIPredicate::ne;
63         break;
64       case ComparisonDirection::LT:
65         comparePredicate = arith::CmpIPredicate::slt;
66         break;
67       case ComparisonDirection::LE:
68         comparePredicate = arith::CmpIPredicate::sle;
69         break;
70       case ComparisonDirection::GT:
71         comparePredicate = arith::CmpIPredicate::sgt;
72         break;
73       case ComparisonDirection::GE:
74         comparePredicate = arith::CmpIPredicate::sge;
75         break;
76     }
77 
78     if (!comparePredicate.has_value()) return failure();
79 
80     rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, comparePredicate.value(),
81                                                lhs, rhs);
82     return success();
83   }
84 };
85 
86 class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
87  public:
88   using OpRewritePattern::OpRewritePattern;
89 
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const90   LogicalResult matchAndRewrite(mhlo::CompareOp op,
91                                 PatternRewriter &rewriter) const override {
92     auto lhs = op.lhs();
93     auto rhs = op.rhs();
94     auto lhsType = lhs.getType().cast<TensorType>();
95     auto rhsType = rhs.getType().cast<TensorType>();
96 
97     // Broadcasting not supported by this rewrite.
98     if (lhsType.getShape() != rhsType.getShape()) return failure();
99 
100     if (!lhsType.getElementType().isa<FloatType>() ||
101         !rhsType.getElementType().isa<FloatType>())
102       return failure();
103 
104     Optional<arith::CmpFPredicate> comparePredicate = llvm::None;
105     switch (op.comparison_direction()) {
106       case ComparisonDirection::EQ:
107         comparePredicate = arith::CmpFPredicate::OEQ;
108         break;
109       case ComparisonDirection::NE:
110         comparePredicate = arith::CmpFPredicate::UNE;
111         break;
112       case ComparisonDirection::LT:
113         comparePredicate = arith::CmpFPredicate::OLT;
114         break;
115       case ComparisonDirection::LE:
116         comparePredicate = arith::CmpFPredicate::OLE;
117         break;
118       case ComparisonDirection::GT:
119         comparePredicate = arith::CmpFPredicate::OGT;
120         break;
121       case ComparisonDirection::GE:
122         comparePredicate = arith::CmpFPredicate::OGE;
123         break;
124     }
125 
126     if (!comparePredicate.has_value()) return failure();
127 
128     rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, comparePredicate.value(),
129                                                lhs, rhs);
130     return success();
131   }
132 };
133 
134 // Replace IotaOp with an integer constant. A ConvertOp is added to
135 // convert the integer constant to iota result type. For complex types, the real
136 // part is replaced with the generated constant and the imaginary part is
137 // replaced with zero tensor.
138 class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
139  public:
140   using OpRewritePattern::OpRewritePattern;
141 
matchAndRewrite(mhlo::IotaOp op,PatternRewriter & rewriter) const142   LogicalResult matchAndRewrite(mhlo::IotaOp op,
143                                 PatternRewriter &rewriter) const override {
144     auto outputType = op.getType().cast<ShapedType>();
145     auto outputSize = outputType.getNumElements();
146     auto dimension = op.iota_dimension();
147     auto maxDimSize = outputType.getDimSize(dimension);
148 
149     auto elementType = outputType.getElementType();
150     int bitwidth;
151 
152     auto complexTy = elementType.dyn_cast<ComplexType>();
153     Type intOrFloatTy = elementType;
154     if (complexTy) intOrFloatTy = complexTy.getElementType();
155 
156     bitwidth = intOrFloatTy.getIntOrFloatBitWidth();
157     llvm::SmallVector<APInt, 10> values;
158     values.reserve(outputSize);
159 
160     int64_t increaseStride = outputSize;
161     for (uint64_t i = 0; i <= dimension; i++) {
162       increaseStride /= outputType.getDimSize(i);
163     }
164 
165     int64_t currentValue = 0;
166     for (int i = 0; i < outputSize; i++) {
167       int64_t value = (currentValue / increaseStride) % maxDimSize;
168       values.push_back(APInt(bitwidth, value));
169       ++currentValue;
170     }
171 
172     auto intShapeType = RankedTensorType::get(
173         outputType.getShape(),
174         IntegerType::get(rewriter.getContext(), bitwidth));
175     auto loc = op.getLoc();
176     auto integerConst = rewriter.create<mlir::arith::ConstantOp>(
177         loc, DenseIntElementsAttr::get(intShapeType, values));
178 
179     auto intOrFloatShapeTy =
180         RankedTensorType::get(outputType.getShape(), intOrFloatTy);
181 
182     auto iotaConst =
183         rewriter.create<ConvertOp>(loc, intOrFloatShapeTy, integerConst);
184 
185     // For int/float types we are done, replace op and return.
186     if (!complexTy) {
187       rewriter.replaceOp(op, iotaConst.getResult());
188       return success();
189     }
190 
191     // For complex types, generate a constant tensor of zeroes for the imaginary
192     // part and use iota_const for real part.
193     auto zeroes = rewriter.create<mlir::arith::ConstantOp>(
194         loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0)));
195     auto imagZeroes =
196         rewriter.create<ConvertOp>(loc, intOrFloatShapeTy, zeroes);
197     rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iotaConst, imagZeroes);
198     return success();
199   }
200 };
201 
202 }  // end anonymous namespace
203 
204 namespace {
205 struct LegalizeToStandardPass
206     : public LegalizeToStandardPassBase<LegalizeToStandardPass> {
getDependentDialectsmlir::mhlo::__anonc4a93fc90311::LegalizeToStandardPass207   void getDependentDialects(DialectRegistry &registry) const override {
208     registry.insert<arith::ArithmeticDialect, math::MathDialect,
209                     func::FuncDialect>();
210   }
211 
212   /// Perform the lowering to Standard dialect.
213   void runOnOperation() override;
214 };
215 }  // end anonymous namespace
216 
217 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeToStdPass()218 createLegalizeToStdPass() {
219   return std::make_unique<LegalizeToStandardPass>();
220 }
221 
populateMhloToStdPatterns(RewritePatternSet * patterns,mlir::MLIRContext * ctx)222 void populateMhloToStdPatterns(RewritePatternSet *patterns,
223                                mlir::MLIRContext *ctx) {
224   mlir::populateWithGenerated(*patterns);
225   patterns->add<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
226 }
227 
228 /// Perform the lowering to standard dialect.
runOnOperation()229 void LegalizeToStandardPass::runOnOperation() {
230   RewritePatternSet patterns(&getContext());
231   mlir::mhlo::populateMhloToStdPatterns(&patterns, &getContext());
232   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
233     return signalPassFailure();
234 }
235 
236 }  // end namespace mhlo
237 }  // end namespace mlir
238