xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/tile_loops_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements the logic for converting `scf.parallel` loops into
// tiled loops.
18 
19 #include <cstdint>
20 #include <tuple>
21 #include <utility>
22 
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/iterator_range.h"
26 #include "mlir-hlo/Transforms/PassDetail.h"
27 #include "mlir-hlo/Transforms/passes.h"
28 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
33 #include "mlir/Dialect/SCF/Utils/Utils.h"
34 #include "mlir/IR/OperationSupport.h"
35 #include "mlir/IR/Value.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37 
38 namespace mlir {
39 
40 using ::mlir::scf::ParallelOp;
41 
42 namespace {
43 
44 // This is the implementation of the TileLoops pass declared in
45 //  include/mlir-hlo/Transforms/passes.td
46 class TileLoopsPass : public TileLoopsPassBase<TileLoopsPass> {
47  public:
48   // Creates a TileLoopsPass with tiles sizes provided through `tile_sizes`
49   // and unroll factors provided through `unroll_factors`.
TileLoopsPass(ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> unrollFactors)50   explicit TileLoopsPass(ArrayRef<int64_t> tileSizes,
51                          ArrayRef<int64_t> unrollFactors) {
52     tile_sizes_ = tileSizes;
53     unroll_factors_ = unrollFactors;
54   }
55 
56   void runOnOperation() override;
57 };
58 
59 }  // namespace
60 
61 // Returns whether the access pattern in `ploop` is "complex". That is, whether
62 // any memref.load op in its region uses indices that don't correspond to the
63 // loop induction variables.
isComplexAccessPattern(ParallelOp ploop)64 static bool isComplexAccessPattern(ParallelOp ploop) {
65   auto isComplex = [&](memref::LoadOp loadOp) {
66     if (!loadOp.getMemRefType().getLayout().isIdentity()) return true;
67     if (loadOp.getIndices().empty()) return false;
68     return loadOp.getIndices() != ploop.getInductionVars();
69   };
70   return llvm::any_of(ploop.getBody()->getOps<memref::LoadOp>(), isComplex);
71 }
72 
runOnOperation()73 void TileLoopsPass::runOnOperation() {
74   SmallVector<int64_t> unrolledTile;
75   if (tile_sizes_.size() == unroll_factors_.size()) {
76     unrolledTile.reserve(tile_sizes_.size());
77     for (int64_t i = 0; i < static_cast<int64_t>(tile_sizes_.size()); ++i)
78       unrolledTile.push_back(tile_sizes_[i] * unroll_factors_[i]);
79   }
80 
81   SmallVector<ParallelOp, 2> ploops;
82   getInnermostParallelLoops(this->getOperation().getOperation(), ploops);
83   for (ParallelOp ploop : ploops) {
84     // Do not unroll if the tiling and unrolling have different rank, or if
85     // the access pattern is complex.
86     if (unrolledTile.empty() || isComplexAccessPattern(ploop)) {
87       tileParallelLoop(ploop, tile_sizes_, /*noMinMaxBounds=*/false);
88       continue;
89     }
90 
91     // Collect lower/upper bounds and step size, if they are constants.
92     auto getConstDefOps = [](OperandRange operands) {
93       return llvm::to_vector(llvm::map_range(operands, [&](Value value) {
94         return value.getDefiningOp<arith::ConstantIndexOp>();
95       }));
96     };
97     auto lower = getConstDefOps(ploop.getLowerBound());
98     auto upper = getConstDefOps(ploop.getUpperBound());
99     auto step = getConstDefOps(ploop.getStep());
100 
101     bool noMinMaxBounds = false;
102     ploop = tileParallelLoop(ploop, unrolledTile, noMinMaxBounds).second;
103     ploop = tileParallelLoop(ploop, unroll_factors_, noMinMaxBounds).second;
104 
105     // Use static upper bound on unrolled loop if possible. That is, if the
106     // unroll factor evenly divides the iteration size of the outer ploop.
107     OpBuilder builder(ploop);
108     Location loc = ploop.getLoc();
109     for (int64_t i = 0; i < static_cast<int64_t>(unrolledTile.size()); ++i) {
110       if (!lower[i] || !upper[i] || !step[i]) continue;
111       int64_t unrollFactor = unroll_factors_[i];
112       int64_t difference = upper[i].value() - lower[i].value();
113       if (difference % (step[i].value() * unrollFactor) != 0) continue;
114       ploop.getUpperBoundMutable().slice(i, 1).assign(
115           builder.create<arith::ConstantIndexOp>(loc, unrollFactor));
116     }
117   }
118 
119   // Apply arithmetic dialect canonicalizations so that
120   // ParallelToGpuLaunchLowering can derive loop-invariant upper bound for
121   // number of iterations.
122   RewritePatternSet patterns(&getContext());
123   getContext()
124       .getOrLoadDialect<arith::ArithmeticDialect>()
125       ->getCanonicalizationPatterns(patterns);
126   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
127     return signalPassFailure();
128 }
129 
createTileLoopsPass(ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> unrollFactors)130 std::unique_ptr<OperationPass<func::FuncOp>> createTileLoopsPass(
131     ArrayRef<int64_t> tileSizes, ArrayRef<int64_t> unrollFactors) {
132   return std::make_unique<TileLoopsPass>(tileSizes, unrollFactors);
133 }
134 
135 }  // namespace mlir
136