1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering MHLO dialect to SCF dialect.
17 #include <utility>
18 
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/StringSwitch.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
29 #include "mlir/IR/Block.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/Diagnostics.h"
34 #include "mlir/IR/ImplicitLocOpBuilder.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/TypeRange.h"
37 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Pass/PassRegistry.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/LogicalResult.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 
44 namespace mlir {
45 namespace mhlo {
46 namespace {
47 
48 // All transformations in this file take mhlo blocks which end with
49 // mhlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an
50 // entire block with the only change being return -> yield.
inlineMhloRegionIntoSCFRegion(PatternRewriter & rewriter,Region & mhlo,Region & scf)51 void inlineMhloRegionIntoSCFRegion(PatternRewriter& rewriter, Region& mhlo,
52                                    Region& scf) {
53   // Remove an existing block, then move the region over.
54   if (!scf.empty()) rewriter.eraseBlock(&scf.back());
55   rewriter.inlineRegionBefore(mhlo, scf, scf.end());
56   // Fix up the terminator.
57   PatternRewriter::InsertionGuard guard(rewriter);
58   rewriter.setInsertionPointToEnd(&scf.back());
59   auto* terminator = scf.back().getTerminator();
60   rewriter.replaceOpWithNewOp<scf::YieldOp>(terminator,
61                                             terminator->getOperands());
62 }
63 
64 // mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor
65 // or a 1 element tensor. To handle this, collapse shape before extracting the
66 // scalar value when necessary.
extractTensorValue(OpBuilder & b,Value tensor)67 Value extractTensorValue(OpBuilder& b, Value tensor) {
68   auto loc = tensor.getLoc();
69   if (tensor.getType().cast<TensorType>().hasRank() &&
70       tensor.getType().cast<TensorType>().getRank() != 0) {
71     tensor = b.create<tensor::CollapseShapeOp>(
72         loc, tensor, SmallVector<ReassociationIndices>());
73   }
74   return b.create<tensor::ExtractOp>(loc, tensor, ValueRange());
75 }
76 
77 // Create a memref descriptor given a pointer and memref type information.
78 struct WhileOpPattern : public OpConversionPattern<mhlo::WhileOp> {
79   using OpConversionPattern<WhileOp>::OpConversionPattern;
80 
matchAndRewritemlir::mhlo::__anon7502c60e0111::WhileOpPattern81   LogicalResult matchAndRewrite(
82       mhlo::WhileOp op, OpAdaptor adaptor,
83       ConversionPatternRewriter& rewriter) const override {
84     auto loc = op.getLoc();
85 
86     auto newWhileOp = rewriter.create<scf::WhileOp>(loc, op.getResultTypes(),
87                                                     adaptor.getOperands());
88 
89     // Inline while condition. The block is the same, except the boolean result
90     // needs to be extracted and used with an scf.condition.
91     rewriter.inlineRegionBefore(op.cond(), newWhileOp.getBefore(),
92                                 newWhileOp.getBefore().end());
93     auto conditionReturn =
94         cast<mhlo::ReturnOp>(newWhileOp.getBefore().front().getTerminator());
95     rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front());
96     Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0));
97     rewriter.replaceOpWithNewOp<scf::ConditionOp>(
98         conditionReturn, i1, newWhileOp.getBeforeArguments());
99 
100     // Inline while body, and only replace the mhlo.return with an scf.yield.
101     inlineMhloRegionIntoSCFRegion(rewriter, op.body(), newWhileOp.getAfter());
102 
103     rewriter.replaceOp(op, newWhileOp.getResults());
104     return success();
105   }
106 };
107 
108 // Create a memref descriptor given a pointer and memref type information.
109 struct IfOpPattern : public OpConversionPattern<mhlo::IfOp> {
110   using OpConversionPattern<IfOp>::OpConversionPattern;
111 
matchAndRewritemlir::mhlo::__anon7502c60e0111::IfOpPattern112   LogicalResult matchAndRewrite(
113       mhlo::IfOp op, OpAdaptor adaptor,
114       ConversionPatternRewriter& rewriter) const override {
115     auto scfIf =
116         rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
117                                    extractTensorValue(rewriter, adaptor.pred()),
118                                    /*withElseRegion=*/true);
119     inlineMhloRegionIntoSCFRegion(rewriter, op.true_branch(),
120                                   scfIf.getThenRegion());
121     inlineMhloRegionIntoSCFRegion(rewriter, op.false_branch(),
122                                   scfIf.getElseRegion());
123     rewriter.replaceOp(op, scfIf.getResults());
124     return success();
125   }
126 };
127 
128 // Create a memref descriptor given a pointer and memref type information.
129 struct CaseOpPattern : public OpConversionPattern<mhlo::CaseOp> {
130   using OpConversionPattern<CaseOp>::OpConversionPattern;
131 
132   // Recursively create if/else ops to handle each possible value in a case op.
createNestedCasesmlir::mhlo::__anon7502c60e0111::CaseOpPattern133   scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor,
134                               PatternRewriter& outerBuilder) const {
135     Location loc = op.getLoc();
136     Value idxValue = adaptor.index();
137     auto finalIdx = op.branches().size() - 2;
138 
139     // Determine if the current index matches the case index.
140     auto scalarType = idxValue.getType();
141     auto constAttr = DenseElementsAttr::get(
142         scalarType,
143         {outerBuilder.getI32IntegerAttr(currentIdx).cast<mlir::Attribute>()});
144     Value currentIdxVal = outerBuilder.create<mhlo::ConstantOp>(
145         loc, idxValue.getType(), constAttr);
146 
147     auto scfIf = outerBuilder.create<scf::IfOp>(
148         loc, op.getResultTypes(),
149         extractTensorValue(outerBuilder, outerBuilder.create<mhlo::CompareOp>(
150                                              loc, idxValue, currentIdxVal,
151                                              ComparisonDirection::EQ)),
152         /*withElseRegion=*/true);
153     inlineMhloRegionIntoSCFRegion(outerBuilder, op.branches()[currentIdx],
154                                   scfIf.getThenRegion());
155     int nextIdx = currentIdx + 1;
156     // Don't recurse for the final default block.
157     if (currentIdx == static_cast<int64_t>(finalIdx)) {
158       inlineMhloRegionIntoSCFRegion(outerBuilder, op.branches()[nextIdx],
159                                     scfIf.getElseRegion());
160     } else {
161       PatternRewriter::InsertionGuard guard(outerBuilder);
162       outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back());
163       auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder);
164       outerBuilder.create<scf::YieldOp>(op.getLoc(), innerIf.getResults());
165     }
166     return scfIf;
167   }
168 
matchAndRewritemlir::mhlo::__anon7502c60e0111::CaseOpPattern169   LogicalResult matchAndRewrite(
170       mhlo::CaseOp op, OpAdaptor adaptor,
171       ConversionPatternRewriter& rewriter) const override {
172     // Inline the op if there is only a default block.
173     if (op.branches().size() == 1) {
174       Block& block = op.branches().front().front();
175       auto results = block.getTerminator()->getOperands();
176       // Remove the mhlo.return terminator, then inline the block.
177       rewriter.eraseOp(block.getTerminator());
178       rewriter.mergeBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(),
179                                 /*argValues=*/{});
180       rewriter.replaceOp(op, results);
181       return success();
182     }
183 
184     // Begin recursion with case 0.
185     rewriter.replaceOp(
186         op, createNestedCases(0, op, adaptor, rewriter).getResults());
187     return success();
188   }
189 };
190 
191 struct LegalizeControlFlowPass
192     : public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
193   // Perform the lowering to MLIR control flow.
runOnOperationmlir::mhlo::__anon7502c60e0111::LegalizeControlFlowPass194   void runOnOperation() override {
195     func::FuncOp f = getOperation();
196     MLIRContext* ctx = f.getContext();
197 
198     RewritePatternSet patterns(&getContext());
199     patterns.add<WhileOpPattern, IfOpPattern, CaseOpPattern>(&getContext());
200 
201     mlir::ConversionTarget target(*ctx);
202     target.markUnknownOpDynamicallyLegal([](Operation*) { return true; });
203     target.addIllegalOp<mhlo::IfOp, mhlo::WhileOp, mhlo::CaseOp>();
204 
205     if (failed(applyPartialConversion(f, target, std::move(patterns)))) {
206       signalPassFailure();
207     }
208   }
209 };
210 
211 }  // namespace
212 }  // namespace mhlo
213 }  // namespace mlir
214 
215 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createLegalizeControlFlowPass()216 mlir::mhlo::createLegalizeControlFlowPass() {
217   return std::make_unique<LegalizeControlFlowPass>();
218 }
219