/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Passes.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
28 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
29
30 namespace tensorflow {
31 namespace {
32
33 #define GEN_PASS_CLASSES
34 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
35
36 using llvm::makeArrayRef;
37 using mlir::BlockAndValueMapping;
38 using mlir::BlockArgument;
39 using mlir::dyn_cast;
40 using mlir::failure;
41 using mlir::Location;
42 using mlir::LogicalResult;
43 using mlir::MLIRContext;
44 using mlir::OpBuilder;
45 using mlir::Operation;
46 using mlir::OpFoldResult;
47 using mlir::OpRewritePattern;
48 using mlir::PatternRewriter;
49 using mlir::SmallVector;
50 using mlir::success;
51 using mlir::Value;
52 using mlir::ValueRange;
53 using mlir::gml_st::LoopOp;
54 using mlir::linalg::FillOp;
55 using mlir::linalg::GenericOp;
56 using mlir::linalg::InitTensorOp;
57 using mlir::linalg::LinalgOp;
58 using mlir::linalg::YieldOp;
59 using mlir::tensor::ExtractSliceOp;
60 using mlir::tensor::InsertSliceOp;
61
GetParallelDimStep(LoopOp tiled_loop)62 SmallVector<OpFoldResult> GetParallelDimStep(LoopOp tiled_loop) {
63 assert(tiled_loop.getNumLoops() == 2 && "Expected a 2D loop");
64 Value step = tiled_loop.isParallelDimension(0) ? tiled_loop.step().front()
65 : tiled_loop.step().back();
66 if (auto constant = step.getDefiningOp<mlir::arith::ConstantOp>()) {
67 return {constant.getValue()};
68 }
69 return {step};
70 }
71
72 // Fuses `linalg.fill` into a loop with a tiled reduction.
73 // Currently, only 2D case is supported. Fusion into a tiled 1D reduction is
74 // also possible.
// Fuses `linalg.fill` into a loop with a tiled reduction.
// Currently, only 2D case is supported. Fusion into a tiled 1D reduction is
// also possible.
struct FuseFillIntoTiledReductionPattern : public OpRewritePattern<GenericOp> {
  explicit FuseFillIntoTiledReductionPattern(MLIRContext *context,
                                             mlir::PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  // Matches a 2-D `linalg.generic` with a single output whose immediate
  // parent is a 2-D `gml_st.loop`, then delegates to RewriteTiledReduction.
  LogicalResult matchAndRewrite(GenericOp linalg_op,
                                PatternRewriter &rewriter) const override {
    if (linalg_op.getNumOutputs() != 1) return failure();
    if (linalg_op.getNumLoops() != 2) return failure();

    // Get immediate parent.
    auto tiled_loop_op =
        dyn_cast<LoopOp>(linalg_op->getParentRegion()->getParentOp());
    if (!tiled_loop_op) return failure();
    if (tiled_loop_op.getNumLoops() != 2) return failure();

    return RewriteTiledReduction(rewriter, tiled_loop_op, linalg_op);
  }

 private:
  // Add a new output argument to the `tiled_loop`. It will be produced by
  // `init_tensor` op with the same shape of the tiled output argument.
  //
  // Rewrite
  //
  // %init = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill)
  //
  // into
  //
  // %init = linalg.init_tensor
  //** %init_tile = linalg.init_tensor [%stride]
  // %fill = linalg.fill(%cst, %init)
  //** linalg.tiled_loop outs(%fill, %init_tile)
  //
  // Returns the loop block argument tied to the newly appended output.
  BlockArgument CloneAndAppendInitTensorToTiledLoop(PatternRewriter &rewriter,
                                                    FillOp fill,
                                                    LoopOp tiled_loop) const {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(fill);

    // NOTE(review): assumes the fill output is produced by
    // `linalg.init_tensor`; `init` is not null-checked before use — TODO
    // confirm callers guarantee this.
    auto init = fill.output().getDefiningOp<InitTensorOp>();

    // The new output is sized by the parallel-dimension step, i.e. it holds
    // exactly one tile of the partial reduction result.
    Value init_clone = rewriter.create<InitTensorOp>(
        init.getLoc(), GetParallelDimStep(tiled_loop),
        init.getType().cast<mlir::RankedTensorType>().getElementType());
    mlir::OpOperand *init_clone_output_operand;
    rewriter.updateRootInPlace(tiled_loop, [&]() {
      init_clone_output_operand =
          &tiled_loop.appendOutputOperand(rewriter, init_clone);
    });
    return tiled_loop.getTiedBlockArgument(*init_clone_output_operand);
  }

  // Fuse `fill` operation into the `tiled_loop`, rewire the `linalg.generic` to
  // use it as the output for the reduced tile. Also create an additional
  // `insert_slice` that updates the new output.
  //
  // Rewrite
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor [%stride]
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //   %reduce = linalg.generic outs (%extract_output_slice)
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice
  // }
  //
  // into
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //** %slice_of_output_tile = tensor.extract_slice %init
  //** %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //** %reduce = linalg.generic outs (%fill_of_output_tile)
  //** %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  //
  // NOTE(review): `loop_output_bb_arg` and `insert_output_slice` are unused in
  // this body; they are kept for interface symmetry with the sibling helpers.
  void FuseFill(PatternRewriter &rewriter, LinalgOp tiled_op, FillOp fill,
                BlockArgument loop_output_bb_arg,
                BlockArgument output_tile_bb_arg,
                ExtractSliceOp extract_output_slice,
                InsertSliceOp insert_output_slice) const {
    Location loc = tiled_op.getLoc();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(tiled_op);

    // Slice the new output tile with the same sizes/strides as the slice of
    // the original output; offset 0 because the tile spans one step only.
    SmallVector<OpFoldResult> offset{rewriter.getIndexAttr(0)};
    Value slice_of_output_tile = rewriter.create<ExtractSliceOp>(
        loc, output_tile_bb_arg, offset, extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    // Fill the tile slice and make the tiled op reduce into it instead of
    // into the slice of the original (already filled) output.
    auto fused_fill =
        rewriter.create<FillOp>(loc, fill.value(), slice_of_output_tile);
    rewriter.updateRootInPlace(tiled_op, [&]() {
      tiled_op.getOutputOperand(0)->set(fused_fill.result());
    });

    // Write back into the new output argument after the tiled op.
    // NOTE(review): the insert source here is the fill result, while the
    // diagram above shows `%reduce` — presumably a later rewrite/fold makes
    // these coincide; verify against the pass pipeline.
    rewriter.setInsertionPointAfter(tiled_op);
    Value cloned_insert = rewriter.create<mlir::tensor::InsertSliceOp>(
        loc, fused_fill.getResult(0), output_tile_bb_arg, offset,
        extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    // Yield the updated tile as the loop's second result (operand index 1).
    auto yield = tiled_op.getOperation()->getBlock()->getTerminator();
    rewriter.updateRootInPlace(
        yield, [&]() { yield->insertOperands(1, cloned_insert); });
  }

  // Add an operation that combines the partial result with the output.
  //
  // Rewrite
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //   %slice_of_output_tile = tensor.extract_slice %init
  //   %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //   %reduce = linalg.generic outs (%fill_of_output_tile)
  //   %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  //
  // into
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //   %slice_of_output_tile = tensor.extract_slice %init
  //   %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //   %reduce = linalg.generic outs (%fill_of_output_tile)
  //   %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //** %combine = linalg.generic ins (%reduce) outs (%extract_output_slice)
  //** %insert_output_slice = tensor.insert_slice %combine into %fill
  //
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  LogicalResult CombineReducedTileWithOutput(
      PatternRewriter &rewriter, LinalgOp tiled_op, Value partial_result,
      ExtractSliceOp extract_output_slice,
      InsertSliceOp insert_output_slice) const {
    rewriter.setInsertionPointAfter(tiled_op);
    auto num_parallel_loops = tiled_op.getNumParallelLoops();
    SmallVector<mlir::StringRef, 3> parallel_iter_types(
        num_parallel_loops, mlir::getParallelIteratorTypeName());
    auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops);

    // `DetectCombiner` (declared elsewhere) identifies the combining op of
    // the reduction body (e.g. an add); fail if it cannot be found.
    auto combiner_or = DetectCombiner(tiled_op);
    if (failed(combiner_or)) return failure();
    Operation *combiner = combiner_or.getValue();

    // Element-wise generic that merges the reduced tile with the current
    // output slice by cloning the detected combiner into its body.
    auto accumulator = rewriter.create<GenericOp>(
        tiled_op.getLoc(), partial_result.getType(),
        makeArrayRef(partial_result),
        makeArrayRef(extract_output_slice.getResult()),
        makeArrayRef({id_map, id_map}), parallel_iter_types,
        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
          BlockAndValueMapping bvm;
          bvm.map(combiner->getOperands(), args);
          Value result_val = b.clone(*combiner, bvm)->getResult(0);
          b.create<YieldOp>(nested_loc, result_val);
        });

    // Redirect the final insert_slice to store the combined value.
    rewriter.updateRootInPlace(insert_output_slice, [&]() {
      insert_output_slice.getSourceMutable().assign(accumulator.getResult(0));
    });
    return success();
  }

  // Unfortunately, there is no way to modify the results of the loop in
  // place. So we have to replace it with a clone that has one result per
  // output operand, including the tile output appended earlier.
  LoopOp CreateLoopWithUpdatedResults(PatternRewriter &rewriter,
                                      LoopOp tiled_loop) const {
    auto loc = tiled_loop.getLoc();
    rewriter.setInsertionPoint(tiled_loop);
    auto new_loop = rewriter.create<LoopOp>(
        loc, mlir::TypeRange(tiled_loop.outputs()), tiled_loop.getOperands(),
        tiled_loop->getAttrs());
    // Move the body region over wholesale instead of cloning op-by-op.
    rewriter.inlineRegionBefore(tiled_loop.region(), new_loop.region(),
                                new_loop.region().begin());

    // The old loop had a single result, so only result 0 needs rewiring.
    rewriter.replaceOp(tiled_loop, new_loop.getResult(0));
    return new_loop;
  }

  // Fuses FillOp producer of the output argument of the LoopOp and inserts
  // an operation that accumulates the partial result, i.e. reduced tile, and
  // the current value of the output tile.
  LogicalResult RewriteTiledReduction(PatternRewriter &rewriter,
                                      LoopOp tiled_loop,
                                      LinalgOp tiled_op) const {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(tiled_op);

    // Find tiled loop output operand and the corresponding block argument.
    mlir::OpOperand *loop_output_operand =
        tiled_loop.findOutputOperand(tiled_loop.outputs().front());
    BlockArgument loop_output_bb_arg =
        tiled_loop.getTiedBlockArgument(*loop_output_operand);

    // Find `linalg.fill` producer of the output.
    auto fill = loop_output_operand->get().getDefiningOp<FillOp>();
    if (!fill) return failure();

    // Find extract_slice/insert_slice pair used to RMW output.
    auto extract_output_slice =
        tiled_op.getOutputOperand(0)->get().getDefiningOp<ExtractSliceOp>();
    if (!extract_output_slice) return failure();

    Value tiled_op_result = tiled_op->getResult(0);
    // NOTE(review): only the first user is inspected; if another user
    // precedes the insert_slice this bails out — TODO confirm user order is
    // guaranteed by the tiling that produced this IR.
    auto insert_output_slice =
        dyn_cast<InsertSliceOp>(*tiled_op_result.getUsers().begin());
    if (!insert_output_slice) return failure();

    // Fuse the output.
    BlockArgument output_tile_bb_arg =
        CloneAndAppendInitTensorToTiledLoop(rewriter, fill, tiled_loop);
    FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, output_tile_bb_arg,
             extract_output_slice, insert_output_slice);
    // We have already modified the loop above, so we need to update the
    // results.
    CreateLoopWithUpdatedResults(rewriter, tiled_loop);
    return CombineReducedTileWithOutput(rewriter, tiled_op, tiled_op_result,
                                        extract_output_slice,
                                        insert_output_slice);
  }
};
320
321 struct FuseFillIntoTiledReductionPass
322 : public FuseFillIntoTiledReductionBase<FuseFillIntoTiledReductionPass> {
runOnOperationtensorflow::__anonaacf1f630111::FuseFillIntoTiledReductionPass323 void runOnOperation() override {
324 auto func = getOperation();
325 auto context = func.getContext();
326
327 mlir::RewritePatternSet patterns(context);
328 patterns.add<FuseFillIntoTiledReductionPattern>(context);
329 (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
330 }
331 };
332
333 } // namespace
334
// Creates a function pass that fuses `linalg.fill` producers into tiled
// reduction loops (see FuseFillIntoTiledReductionPattern above).
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseFillIntoTiledReductionPass() {
  return std::make_unique<FuseFillIntoTiledReductionPass>();
}
339
340 } // namespace tensorflow
341