1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
20 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 namespace mlir {
28 namespace gml_st {
29 namespace {
30
31 /// Converts gml_st.loop to SCF loop nest. All parallel dimensions are collected
32 /// into an scf.parallel loop and all sequential dimensions will result in a
33 /// nested scf.for loop nest. The pattern assumes that a gml_st.loop with
34 /// iterator_types ["reduction", "parallel", "reduction"] can be reordered.
35 struct LoopToSCFPattern : public OpRewritePattern<LoopOp> {
36 using OpRewritePattern<LoopOp>::OpRewritePattern;
37
matchAndRewritemlir::gml_st::__anoncf4583c30111::LoopToSCFPattern38 LogicalResult matchAndRewrite(LoopOp loop,
39 PatternRewriter &rewriter) const override {
40 // Fail conversion if the `gml_st.loop` has not been bufferized.
41 if (!loop.hasBufferSemantics()) return failure();
42
43 // Collect loop control parameters for parallel and sequential dimensions.
44 SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
45 SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
46 for (const auto &en :
47 llvm::enumerate(llvm::zip(loop.lowerBound(), loop.upperBound(),
48 loop.step(), loop.getInductionVars()))) {
49 Value lb, ub, step, iv;
50 std::tie(lb, ub, step, iv) = en.value();
51 if (loop.isParallelDimension(en.index())) {
52 parLBs.push_back(lb);
53 parUBs.push_back(ub);
54 parSteps.push_back(step);
55 parIVs.push_back(iv);
56 } else {
57 seqLBs.push_back(lb);
58 seqUBs.push_back(ub);
59 seqSteps.push_back(step);
60 seqIVs.push_back(iv);
61 }
62 }
63
64 Location loc = loop.getLoc();
65 auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
66 ValueRange ivs) {
67 BlockAndValueMapping bvm;
68 bvm.map(parIVs, ivs);
69 bvm.map(loop.getRegionInputArgs(), loop.inputs());
70 bvm.map(loop.getRegionOutputArgs(), loop.outputs());
71
72 // If not all dimensions of the gml_st.loop are parallel, an scf.for loop
73 // nest is generated.
74 if (!seqIVs.empty()) {
75 scf::LoopNest nest =
76 scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
77 [&](OpBuilder & /*builder*/, Location /*loc*/,
78 ValueRange ivs) { bvm.map(seqIVs, ivs); });
79 builder.setInsertionPointToStart(nest.loops.back().getBody());
80 }
81 for (auto &op : loop.getBody()->without_terminator())
82 builder.clone(op, bvm);
83 };
84
85 if (parIVs.empty()) {
86 generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
87 } else {
88 rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
89 generateForLoopNestAndCloneBody);
90 }
91 rewriter.eraseOp(loop);
92 return success();
93 }
94 };
95
96 /// Converts gml_st.parallel to SCF loop nest.
97 struct ParallelOpToSCFPattern : public OpRewritePattern<ParallelOp> {
98 using OpRewritePattern<ParallelOp>::OpRewritePattern;
99
matchAndRewritemlir::gml_st::__anoncf4583c30111::ParallelOpToSCFPattern100 LogicalResult matchAndRewrite(ParallelOp loop,
101 PatternRewriter &rewriter) const override {
102 // Fail conversion if the loop has not been bufferized.
103 if (!loop.hasBufferSemantics()) return failure();
104
105 auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) {
106 BlockAndValueMapping bvm;
107 bvm.map(loop.getInductionVars(), ivs);
108
109 for (auto &op : loop.getBody()->without_terminator())
110 builder.clone(op, bvm);
111 };
112
113 rewriter.create<scf::ParallelOp>(loop.getLoc(), loop.lowerBound(),
114 loop.upperBound(), loop.step(), cloneBody);
115
116 rewriter.eraseOp(loop);
117 return success();
118 }
119 };
120
121 /// Converts gml_st.for to SCF loop nest.
122 struct ForOpToSCFPattern : public OpRewritePattern<ForOp> {
123 using OpRewritePattern<ForOp>::OpRewritePattern;
124
matchAndRewritemlir::gml_st::__anoncf4583c30111::ForOpToSCFPattern125 LogicalResult matchAndRewrite(ForOp loop,
126 PatternRewriter &rewriter) const override {
127 // Fail conversion if the loop has not been bufferized.
128 if (!loop.hasBufferSemantics()) return failure();
129
130 auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) {
131 BlockAndValueMapping bvm;
132 bvm.map(loop.getInductionVars(), ivs);
133 bvm.map(loop.getBody()->getArguments().take_back(loop.outputs().size()),
134 loop.outputs());
135
136 for (auto &op : loop.getBody()->without_terminator())
137 builder.clone(op, bvm);
138 };
139
140 scf::buildLoopNest(rewriter, loop.getLoc(), loop.lowerBound(),
141 loop.upperBound(), loop.step(), cloneBody);
142 rewriter.eraseOp(loop);
143 return success();
144 }
145 };
146
147 struct GmlStToScfPass : public GmlStToScfBase<GmlStToScfPass> {
runOnOperationmlir::gml_st::__anoncf4583c30111::GmlStToScfPass148 void runOnOperation() override {
149 MLIRContext *context = &getContext();
150 RewritePatternSet patterns(context);
151 patterns.add<ForOpToSCFPattern, LoopToSCFPattern, ParallelOpToSCFPattern>(
152 patterns.getContext());
153 if (failed(applyPatternsAndFoldGreedily(getOperation(),
154 std::move(patterns)))) {
155 signalPassFailure();
156 }
157 }
158 };
159
160 } // namespace
161
createGmlStToScfPass()162 std::unique_ptr<OperationPass<func::FuncOp>> createGmlStToScfPass() {
163 return std::make_unique<GmlStToScfPass>();
164 }
165
166 } // namespace gml_st
167 } // namespace mlir
168