1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 
19 #include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Passes.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
28 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 #define GEN_PASS_CLASSES
34 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
35 
36 using llvm::makeArrayRef;
37 using mlir::BlockAndValueMapping;
38 using mlir::BlockArgument;
39 using mlir::dyn_cast;
40 using mlir::failure;
41 using mlir::Location;
42 using mlir::LogicalResult;
43 using mlir::MLIRContext;
44 using mlir::OpBuilder;
45 using mlir::Operation;
46 using mlir::OpFoldResult;
47 using mlir::OpRewritePattern;
48 using mlir::PatternRewriter;
49 using mlir::SmallVector;
50 using mlir::success;
51 using mlir::Value;
52 using mlir::ValueRange;
53 using mlir::gml_st::LoopOp;
54 using mlir::linalg::FillOp;
55 using mlir::linalg::GenericOp;
56 using mlir::linalg::InitTensorOp;
57 using mlir::linalg::LinalgOp;
58 using mlir::linalg::YieldOp;
59 using mlir::tensor::ExtractSliceOp;
60 using mlir::tensor::InsertSliceOp;
61 
GetParallelDimStep(LoopOp tiled_loop)62 SmallVector<OpFoldResult> GetParallelDimStep(LoopOp tiled_loop) {
63   assert(tiled_loop.getNumLoops() == 2 && "Expected a 2D loop");
64   Value step = tiled_loop.isParallelDimension(0) ? tiled_loop.step().front()
65                                                  : tiled_loop.step().back();
66   if (auto constant = step.getDefiningOp<mlir::arith::ConstantOp>()) {
67     return {constant.getValue()};
68   }
69   return {step};
70 }
71 
72 // Fuses `linalg.fill` into a loop with a tiled reduction.
73 // Currently, only 2D case is supported. Fusion into a tiled 1D reduction is
74 // also possible.
// Fuses `linalg.fill` into a loop with a tiled reduction.
// Currently, only 2D case is supported. Fusion into a tiled 1D reduction is
// also possible.
struct FuseFillIntoTiledReductionPattern : public OpRewritePattern<GenericOp> {
  explicit FuseFillIntoTiledReductionPattern(MLIRContext *context,
                                             mlir::PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  // Matches a 2D `linalg.generic` with a single output whose immediate
  // parent is a 2D `gml_st.loop`, then delegates to RewriteTiledReduction.
  LogicalResult matchAndRewrite(GenericOp linalg_op,
                                PatternRewriter &rewriter) const override {
    if (linalg_op.getNumOutputs() != 1) return failure();
    if (linalg_op.getNumLoops() != 2) return failure();

    // Get immediate parent.
    auto tiled_loop_op =
        dyn_cast<LoopOp>(linalg_op->getParentRegion()->getParentOp());
    if (!tiled_loop_op) return failure();
    if (tiled_loop_op.getNumLoops() != 2) return failure();

    return RewriteTiledReduction(rewriter, tiled_loop_op, linalg_op);
  }

 private:
  // Add a new output argument to the `tiled_loop`. It will be produced by
  // `init_tensor` op with the same shape of the tiled output argument.
  //
  // Rewrite
  //
  //   %init = linalg.init_tensor
  //   %fill = linalg.fill(%cst, %init)
  //   linalg.tiled_loop outs(%fill)
  //
  // into
  //
  //   %init = linalg.init_tensor
  //** %init_tile = linalg.init_tensor [%stride]
  //   %fill = linalg.fill(%cst, %init)
  //** linalg.tiled_loop outs(%fill, %init_tile)
  BlockArgument CloneAndAppendInitTensorToTiledLoop(PatternRewriter &rewriter,
                                                    FillOp fill,
                                                    LoopOp tiled_loop) const {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(fill);

    // NOTE(review): assumes the fill output is produced by
    // `linalg.init_tensor`; `init` is dereferenced below without a null
    // check — confirm the caller guarantees this.
    auto init = fill.output().getDefiningOp<InitTensorOp>();

    // The new output is a 1D tensor sized by the parallel-dimension step,
    // with the same element type as the original init tensor.
    Value init_clone = rewriter.create<InitTensorOp>(
        init.getLoc(), GetParallelDimStep(tiled_loop),
        init.getType().cast<mlir::RankedTensorType>().getElementType());
    mlir::OpOperand *init_clone_output_operand;
    // Appending an output operand mutates the loop in place, so notify the
    // rewriter via updateRootInPlace.
    rewriter.updateRootInPlace(tiled_loop, [&]() {
      init_clone_output_operand =
          &tiled_loop.appendOutputOperand(rewriter, init_clone);
    });
    return tiled_loop.getTiedBlockArgument(*init_clone_output_operand);
  }

  // Fuse `fill` operation into the `tiled_loop`, rewire the `linalg.generic` to
  // use it as the output for the reduced tile. Also create an additional
  // `insert_slice` that updates the new output.
  //
  // Rewrite
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor [%stride]
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //   %reduce = linalg.generic outs (%extract_output_slice)
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice
  // }
  //
  // into
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //** %slice_of_output_tile = tensor.extract_slice %init
  //** %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //** %reduce = linalg.generic outs (%fill_of_output_tile)
  //** %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  void FuseFill(PatternRewriter &rewriter, LinalgOp tiled_op, FillOp fill,
                BlockArgument loop_output_bb_arg,
                BlockArgument output_tile_bb_arg,
                ExtractSliceOp extract_output_slice,
                InsertSliceOp insert_output_slice) const {
    Location loc = tiled_op.getLoc();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(tiled_op);

    // Extract a slice of the new output tile; sizes/strides mirror the
    // existing output extract_slice, offset is 0 in the 1D tile.
    SmallVector<OpFoldResult> offset{rewriter.getIndexAttr(0)};
    Value slice_of_output_tile = rewriter.create<ExtractSliceOp>(
        loc, output_tile_bb_arg, offset, extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    // Fill the tile slice with the same value and make the reduction write
    // into it instead of the original output slice.
    auto fused_fill =
        rewriter.create<FillOp>(loc, fill.value(), slice_of_output_tile);
    rewriter.updateRootInPlace(tiled_op, [&]() {
      tiled_op.getOutputOperand(0)->set(fused_fill.result());
    });

    // After the reduction, write the (now reduced) tile slice back into the
    // new output tile block argument.
    rewriter.setInsertionPointAfter(tiled_op);
    Value cloned_insert = rewriter.create<mlir::tensor::InsertSliceOp>(
        loc, fused_fill.getResult(0), output_tile_bb_arg, offset,
        extract_output_slice.getMixedSizes(),
        extract_output_slice.getMixedStrides());

    // Yield the tile update as the second terminator operand (index 1),
    // matching the output operand appended to the loop.
    auto yield = tiled_op.getOperation()->getBlock()->getTerminator();
    rewriter.updateRootInPlace(
        yield, [&]() { yield->insertOperands(1, cloned_insert); });
  }

  // Add an operation that combines the partial result with the output.
  //
  // Rewrite
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //   %slice_of_output_tile = tensor.extract_slice %init
  //   %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //   %reduce = linalg.generic outs (%fill_of_output_tile)
  //   %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //   %insert_output_slice = tensor.insert_slice %reduce into %fill
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  //
  // into
  //
  // %init = linalg.init_tensor
  // %init_tile = linalg.init_tensor
  // %fill = linalg.fill(%cst, %init)
  // linalg.tiled_loop outs(%fill, %init_tile) {
  //   %extract_output_slice = tensor.extract_slice %fill
  //
  //   %slice_of_output_tile = tensor.extract_slice %init
  //   %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile)
  //   %reduce = linalg.generic outs (%fill_of_output_tile)
  //   %update_output_tile = tensor.insert_slice %reduce into %init_tile
  //
  //** %combine = linalg.generic ins (%reduce) outs (%extract_output_slice)
  //** %insert_output_slice = tensor.insert_slice %combine into %fill
  //
  //   linalg.yield %insert_output_slice, %update_output_tile
  // }
  LogicalResult CombineReducedTileWithOutput(
      PatternRewriter &rewriter, LinalgOp tiled_op, Value partial_result,
      ExtractSliceOp extract_output_slice,
      InsertSliceOp insert_output_slice) const {
    rewriter.setInsertionPointAfter(tiled_op);
    auto num_parallel_loops = tiled_op.getNumParallelLoops();
    SmallVector<mlir::StringRef, 3> parallel_iter_types(
        num_parallel_loops, mlir::getParallelIteratorTypeName());
    auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops);

    // The combiner (e.g. the scalar op that folds the reduction) is reused
    // to accumulate the partial tile into the output slice.
    auto combiner_or = DetectCombiner(tiled_op);
    if (failed(combiner_or)) return failure();
    Operation *combiner = combiner_or.getValue();

    // Elementwise accumulator: clones the detected combiner on
    // (partial_result, extract_output_slice) pairs.
    auto accumulator = rewriter.create<GenericOp>(
        tiled_op.getLoc(), partial_result.getType(),
        makeArrayRef(partial_result),
        makeArrayRef(extract_output_slice.getResult()),
        makeArrayRef({id_map, id_map}), parallel_iter_types,
        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
          BlockAndValueMapping bvm;
          bvm.map(combiner->getOperands(), args);
          Value result_val = b.clone(*combiner, bvm)->getResult(0);
          b.create<YieldOp>(nested_loc, result_val);
        });

    // Insert the accumulated value into the output instead of the raw
    // partial reduction.
    rewriter.updateRootInPlace(insert_output_slice, [&]() {
      insert_output_slice.getSourceMutable().assign(accumulator.getResult(0));
    });
    return success();
  }

  // Unfortunaly, there is no way to modify the results of the loop inplace. So
  // we have to replace it with a clone.
  LoopOp CreateLoopWithUpdatedResults(PatternRewriter &rewriter,
                                      LoopOp tiled_loop) const {
    auto loc = tiled_loop.getLoc();
    rewriter.setInsertionPoint(tiled_loop);
    // The clone carries the result types of *all* outputs, including the
    // tile output appended earlier.
    auto new_loop = rewriter.create<LoopOp>(
        loc, mlir::TypeRange(tiled_loop.outputs()), tiled_loop.getOperands(),
        tiled_loop->getAttrs());
    rewriter.inlineRegionBefore(tiled_loop.region(), new_loop.region(),
                                new_loop.region().begin());

    // The old loop had exactly one result; map it to the clone's first one.
    rewriter.replaceOp(tiled_loop, new_loop.getResult(0));
    return new_loop;
  }

  // Fuses FillOp producer of the output argument of the LoopOp and inserts
  // an operation that accumulates the partial result, i.e. reduced tile, and
  // the current value of the output tile.
  LogicalResult RewriteTiledReduction(PatternRewriter &rewriter,
                                      LoopOp tiled_loop,
                                      LinalgOp tiled_op) const {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(tiled_op);

    // Find tiled loop output operand and the corresponding block argument.
    mlir::OpOperand *loop_output_operand =
        tiled_loop.findOutputOperand(tiled_loop.outputs().front());
    BlockArgument loop_output_bb_arg =
        tiled_loop.getTiedBlockArgument(*loop_output_operand);

    // Find `linalg.fill` producer of the output.
    auto fill = loop_output_operand->get().getDefiningOp<FillOp>();
    if (!fill) return failure();

    // Find extract_slice/insert_slice pair used to RMW output.
    auto extract_output_slice =
        tiled_op.getOutputOperand(0)->get().getDefiningOp<ExtractSliceOp>();
    if (!extract_output_slice) return failure();

    // NOTE(review): only the first user of the reduction result is
    // inspected; if there are multiple users the pattern bails out only
    // when that first user is not an insert_slice — confirm this is the
    // intended matching condition.
    Value tiled_op_result = tiled_op->getResult(0);
    auto insert_output_slice =
        dyn_cast<InsertSliceOp>(*tiled_op_result.getUsers().begin());
    if (!insert_output_slice) return failure();

    // Fuse the output.
    BlockArgument output_tile_bb_arg =
        CloneAndAppendInitTensorToTiledLoop(rewriter, fill, tiled_loop);
    FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, output_tile_bb_arg,
             extract_output_slice, insert_output_slice);
    // We have already modified the loop above, so we need to update the
    // results.
    CreateLoopWithUpdatedResults(rewriter, tiled_loop);
    return CombineReducedTileWithOutput(rewriter, tiled_op, tiled_op_result,
                                        extract_output_slice,
                                        insert_output_slice);
  }
};
320 
321 struct FuseFillIntoTiledReductionPass
322     : public FuseFillIntoTiledReductionBase<FuseFillIntoTiledReductionPass> {
runOnOperationtensorflow::__anonaacf1f630111::FuseFillIntoTiledReductionPass323   void runOnOperation() override {
324     auto func = getOperation();
325     auto context = func.getContext();
326 
327     mlir::RewritePatternSet patterns(context);
328     patterns.add<FuseFillIntoTiledReductionPattern>(context);
329     (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
330   }
331 };
332 
333 }  // namespace
334 
// Public factory: returns a new instance of the pass that fuses
// `linalg.fill` producers into tiled reduction loops.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseFillIntoTiledReductionPass() {
  return std::make_unique<FuseFillIntoTiledReductionPass>();
}
339 
340 }  // namespace tensorflow
341