1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 
19 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
20 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace mlir {
28 namespace gml_st {
29 namespace {
30 
31 /// Converts gml_st.loop to SCF loop nest. All parallel dimensions are collected
32 /// into an scf.parallel loop and all sequential dimensions will result in a
33 /// nested scf.for loop nest. The pattern assumes that a gml_st.loop with
34 /// iterator_types ["reduction", "parallel", "reduction"] can be reordered.
35 struct LoopToSCFPattern : public OpRewritePattern<LoopOp> {
36   using OpRewritePattern<LoopOp>::OpRewritePattern;
37 
matchAndRewritemlir::gml_st::__anoncf4583c30111::LoopToSCFPattern38   LogicalResult matchAndRewrite(LoopOp loop,
39                                 PatternRewriter &rewriter) const override {
40     // Fail conversion if the `gml_st.loop` has not been bufferized.
41     if (!loop.hasBufferSemantics()) return failure();
42 
43     // Collect loop control parameters for parallel and sequential dimensions.
44     SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
45     SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
46     for (const auto &en :
47          llvm::enumerate(llvm::zip(loop.lowerBound(), loop.upperBound(),
48                                    loop.step(), loop.getInductionVars()))) {
49       Value lb, ub, step, iv;
50       std::tie(lb, ub, step, iv) = en.value();
51       if (loop.isParallelDimension(en.index())) {
52         parLBs.push_back(lb);
53         parUBs.push_back(ub);
54         parSteps.push_back(step);
55         parIVs.push_back(iv);
56       } else {
57         seqLBs.push_back(lb);
58         seqUBs.push_back(ub);
59         seqSteps.push_back(step);
60         seqIVs.push_back(iv);
61       }
62     }
63 
64     Location loc = loop.getLoc();
65     auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
66                                                ValueRange ivs) {
67       BlockAndValueMapping bvm;
68       bvm.map(parIVs, ivs);
69       bvm.map(loop.getRegionInputArgs(), loop.inputs());
70       bvm.map(loop.getRegionOutputArgs(), loop.outputs());
71 
72       // If not all dimensions of the gml_st.loop are parallel, an scf.for loop
73       // nest is generated.
74       if (!seqIVs.empty()) {
75         scf::LoopNest nest =
76             scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
77                                [&](OpBuilder & /*builder*/, Location /*loc*/,
78                                    ValueRange ivs) { bvm.map(seqIVs, ivs); });
79         builder.setInsertionPointToStart(nest.loops.back().getBody());
80       }
81       for (auto &op : loop.getBody()->without_terminator())
82         builder.clone(op, bvm);
83     };
84 
85     if (parIVs.empty()) {
86       generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
87     } else {
88       rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
89                                        generateForLoopNestAndCloneBody);
90     }
91     rewriter.eraseOp(loop);
92     return success();
93   }
94 };
95 
96 /// Converts gml_st.parallel to SCF loop nest.
97 struct ParallelOpToSCFPattern : public OpRewritePattern<ParallelOp> {
98   using OpRewritePattern<ParallelOp>::OpRewritePattern;
99 
matchAndRewritemlir::gml_st::__anoncf4583c30111::ParallelOpToSCFPattern100   LogicalResult matchAndRewrite(ParallelOp loop,
101                                 PatternRewriter &rewriter) const override {
102     // Fail conversion if the loop has not been bufferized.
103     if (!loop.hasBufferSemantics()) return failure();
104 
105     auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) {
106       BlockAndValueMapping bvm;
107       bvm.map(loop.getInductionVars(), ivs);
108 
109       for (auto &op : loop.getBody()->without_terminator())
110         builder.clone(op, bvm);
111     };
112 
113     rewriter.create<scf::ParallelOp>(loop.getLoc(), loop.lowerBound(),
114                                      loop.upperBound(), loop.step(), cloneBody);
115 
116     rewriter.eraseOp(loop);
117     return success();
118   }
119 };
120 
121 /// Converts gml_st.for to SCF loop nest.
122 struct ForOpToSCFPattern : public OpRewritePattern<ForOp> {
123   using OpRewritePattern<ForOp>::OpRewritePattern;
124 
matchAndRewritemlir::gml_st::__anoncf4583c30111::ForOpToSCFPattern125   LogicalResult matchAndRewrite(ForOp loop,
126                                 PatternRewriter &rewriter) const override {
127     // Fail conversion if the loop has not been bufferized.
128     if (!loop.hasBufferSemantics()) return failure();
129 
130     auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) {
131       BlockAndValueMapping bvm;
132       bvm.map(loop.getInductionVars(), ivs);
133       bvm.map(loop.getBody()->getArguments().take_back(loop.outputs().size()),
134               loop.outputs());
135 
136       for (auto &op : loop.getBody()->without_terminator())
137         builder.clone(op, bvm);
138     };
139 
140     scf::buildLoopNest(rewriter, loop.getLoc(), loop.lowerBound(),
141                        loop.upperBound(), loop.step(), cloneBody);
142     rewriter.eraseOp(loop);
143     return success();
144   }
145 };
146 
147 struct GmlStToScfPass : public GmlStToScfBase<GmlStToScfPass> {
runOnOperationmlir::gml_st::__anoncf4583c30111::GmlStToScfPass148   void runOnOperation() override {
149     MLIRContext *context = &getContext();
150     RewritePatternSet patterns(context);
151     patterns.add<ForOpToSCFPattern, LoopToSCFPattern, ParallelOpToSCFPattern>(
152         patterns.getContext());
153     if (failed(applyPatternsAndFoldGreedily(getOperation(),
154                                             std::move(patterns)))) {
155       signalPassFailure();
156     }
157   }
158 };
159 
160 }  // namespace
161 
createGmlStToScfPass()162 std::unique_ptr<OperationPass<func::FuncOp>> createGmlStToScfPass() {
163   return std::make_unique<GmlStToScfPass>();
164 }
165 
166 }  // namespace gml_st
167 }  // namespace mlir
168