/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

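// This file implements a pass that rewrites mhlo.reduce ops into equivalent
// reductions over at most two dimensions by grouping consecutive reduction
// and parallel dimensions, collapsing each group, and inserting transposes
// where needed.
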
#include <algorithm>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

LogicalResult tryLowerToCollapseShape(
    ReduceOp op, RankedTensorType argTy, Value arg,
    SmallVector<int64_t>& orderedReductionDims, PatternRewriter& rewriter) {
  // This only works for trivial reductions where all declared reduction
  // dimensions are of extent 1.
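  // For example, reducing tensor<1x5x1xf32> over dimensions {0, 2} is
  // equivalent to collapsing the argument to tensor<5xf32>.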
  if (!llvm::all_of(orderedReductionDims,
                    [argTy](int64_t i) { return argTy.getDimSize(i) == 1; })) {
    return failure();
  }

  int64_t argRank = argTy.getRank();
  int64_t numReductionDims = orderedReductionDims.size();

  int64_t j = 0;
  auto isDeclaredAsReductionDim = [&](int64_t i) {
    if (j < numReductionDims && orderedReductionDims[j] == i) {
      j++;
      return true;
    }
    return false;
  };

  // Build reassociation indices.
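  // Each group consists of one non-reduction dimension together with the
  // unit-size reduction dimensions adjacent to it. For example, for a rank-4
  // argument reduced over dimensions {0, 2}, this yields the reassociation
  // [[0, 1, 2], [3]], collapsing tensor<1x5x1x7xf32> to tensor<5x7xf32>.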
  SmallVector<ReassociationIndices, 4> reassociation;
  int64_t iBegin = 0;
  int64_t i = 0;
  while (i < argRank && isDeclaredAsReductionDim(i)) i++;
  while (i < argRank) {
    i++;
    while (i < argRank && isDeclaredAsReductionDim(i)) i++;
    reassociation.push_back(llvm::to_vector(llvm::seq(iBegin, i)));
    iBegin = i;
  }

  // Lower reduction op to collapse shape op.
  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, arg, reassociation);
  return success();
}

enum class DimensionKind {
  kParallel,
  kReduction,
  kDegenerate,
};

struct DimensionGroup {
  DimensionKind kind;
  int64_t begin;
  int64_t end;
  int64_t size() { return end - begin; }
};

// Groups consecutive dimensions of a reduction argument by their kind, i.e.
// whether they are reduction or parallel dimensions. Dimensions of size 1 can
// be considered as either kind.
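// For example, tensor<10x1x2x3xf32> with reduction dimensions {2, 3} yields
// the groups [parallel: 0..2) and [reduction: 2..4); the size-1 dimension 1
// joins the preceding parallel group.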
void groupDimensions(RankedTensorType argTy,
                     SmallVector<int64_t> orderedReductionDims,
                     SmallVector<DimensionGroup>& groups) {
  int64_t argRank = argTy.getRank();
  int64_t numReductionDims = orderedReductionDims.size();
  int64_t j = 0;
  for (int64_t i = 0; i < argRank; ++i) {
    // Check if the i-th dimension is one of the declared reduction dimensions.
    bool isDeclaredAsReductionDim = false;
    if (j < numReductionDims && i == orderedReductionDims[j]) {
      isDeclaredAsReductionDim = true;
      j++;
    }

    // Use the declared dimension kind unless the dimension is of extent 1, in
    // which case we can consider it either kind. We exploit this to form
    // maximal dimension groups.
    DimensionKind kind = isDeclaredAsReductionDim ? DimensionKind::kReduction
                                                  : DimensionKind::kParallel;
    if (argTy.getDimSize(i) == 1) kind = DimensionKind::kDegenerate;

    // Start a new dimension group if the dimension kind conflicts with the
    // trailing kind.
    if (groups.empty() || (groups.back().kind != kind &&
                           groups.back().kind != DimensionKind::kDegenerate &&
                           kind != DimensionKind::kDegenerate)) {
      groups.push_back({kind, i, i});
    }

    // Include dimension in trailing group and concretize dimension kind if
    // necessary.
    if (groups.back().kind == DimensionKind::kDegenerate)
      groups.back().kind = kind;
    groups.back().end++;
  }
}

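// Lowers a multi-dimensional reduction to an equivalent 1D or 2D reduction by
// collapsing dimension groups and, if needed, transposing so that all
// reduction dimensions form one contiguous partition.
// Example (row reductions, i.e. preferColumnsReductions = false): reducing
// tensor<2x3x4x5xf32> over dimensions {1, 2} becomes
//   collapse [[0], [1, 2], [3]] : tensor<2x12x5xf32>
//   transpose [0, 2, 1]         : tensor<2x5x12xf32>
//   collapse [[0, 1], [2]]      : tensor<10x12xf32>
//   reduce over dimension 1     : tensor<10xf32>
//   expand [[0, 1]]             : tensor<2x5xf32>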
LogicalResult tryLowerTo1DOr2DReduction(
    ReduceOp op, RankedTensorType argTy, Value arg,
    SmallVector<int64_t>& orderedReductionDims, bool preferColumnsReductions,
    PatternRewriter& rewriter) {
  // Group the argument dimensions by their kind.
  SmallVector<DimensionGroup> dimGroups;
  groupDimensions(argTy, orderedReductionDims, dimGroups);

  // Do not (re-)apply if the dimensions are already fully collapsed.
  if (dimGroups.size() <= 2 &&
      llvm::all_of(dimGroups, [](auto g) { return g.size() == 1; })) {
    return failure();
  }

  // Determine whether a dynamic reshape is needed for the final result.
  int64_t numDynParallelDims = 0;
  for (auto group : dimGroups) {
    if (group.kind != DimensionKind::kParallel) continue;
    for (int64_t i = group.begin; i < group.end; i++) {
      if (argTy.isDynamicDim(i)) numDynParallelDims++;
    }
  }
  bool requiresDynamicReshape = numDynParallelDims > 1;
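  // With at most one dynamic parallel dimension, the result shape can be
  // restored with a static expand_shape; splitting a collapsed extent back
  // into more than one dynamic dimension requires a dynamic reshape.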

  // Reify the result shape early so that the pattern can fail without altering
  // the IR.
  Optional<Value> resultShape;
  if (requiresDynamicReshape) {
    llvm::SmallVector<Value, 1> reifiedShapes;
    if (failed(llvm::cast<InferShapedTypeOpInterface>(op.getOperation())
                   .reifyReturnTypeShapes(rewriter, op->getOperands(),
                                          reifiedShapes))) {
      return failure();
    }
    assert(reifiedShapes.size() == 1 && "expect exactly one shape");
    resultShape = reifiedShapes.front();
  }

  // Collapse dimension groups so that all adjacent dimensions of the
  // intermediate result are of a different kind.
  Value intermResult = arg;
  auto loc = op.getLoc();
  bool requiresCollapse =
      llvm::any_of(dimGroups, [&](auto g) { return g.size() > 1; });
  if (requiresCollapse) {
    auto reassociation =
        llvm::to_vector(llvm::map_range(dimGroups, [&](auto g) {
          return llvm::to_vector<2>(llvm::seq<int64_t>(g.begin, g.end));
        }));
    intermResult = rewriter.create<tensor::CollapseShapeOp>(loc, intermResult,
                                                            reassociation);
  }

  // If required, transpose the intermediate result so that the dimension kinds
  // form two partitions, which can be collapsed to a 2D intermediate result.
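  // For example, with groups [P, R, P, R] and preferColumnsReductions set, the
  // permutation moves all reduction groups to the front: [R, R, P, P].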
  bool requiresTranspose = dimGroups.size() > 2;
  if (requiresTranspose) {
    // Materialize transpose.
    DimensionKind leadingDimKind = preferColumnsReductions
                                       ? DimensionKind::kReduction
                                       : DimensionKind::kParallel;
    DimensionKind trailingDimKind = preferColumnsReductions
                                        ? DimensionKind::kParallel
                                        : DimensionKind::kReduction;
    SmallVector<int64_t> perm;
    for (int64_t i = 0; i < static_cast<int64_t>(dimGroups.size()); i++) {
      if (dimGroups[i].kind == leadingDimKind) perm.push_back(i);
    }
    int64_t numLeadingDims = perm.size();
    for (int64_t i = 0; i < static_cast<int64_t>(dimGroups.size()); i++) {
      if (dimGroups[i].kind == trailingDimKind) perm.push_back(i);
    }
    auto permAttr = rewriter.getI64TensorAttr(perm);
    intermResult = rewriter.create<TransposeOp>(loc, intermResult, permAttr)
                       ->getResults()
                       .front();

    // Collapse the intermediate result to rank 2.
    SmallVector<ReassociationIndices, 2> reassociation = {
        llvm::to_vector<2>(llvm::seq<int64_t>(0, numLeadingDims)),
        llvm::to_vector<2>(llvm::seq<int64_t>(numLeadingDims, perm.size()))};
    intermResult = rewriter.create<tensor::CollapseShapeOp>(loc, intermResult,
                                                            reassociation);
  }

  // Materialize inner 1D or 2D reduction.
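  // Reducing dimension 0 of the 2D intermediate result is a column reduction;
  // reducing dimension 1 is a row reduction.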
  bool leadingReduction =
      requiresTranspose ? preferColumnsReductions
                        : dimGroups.front().kind == DimensionKind::kReduction;
  int64_t reductionDim = leadingReduction ? 0 : 1;
  auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim});
  Value initVal = op.init_values().front();
  auto reductionOp =
      rewriter.create<ReduceOp>(loc, intermResult, initVal, reductionDimAttr);
  rewriter.inlineRegionBefore(op.body(), reductionOp.body(),
                              reductionOp.body().begin());
  intermResult = reductionOp->getResults().front();

  // Restore the expected shape by dynamic reshape, if required.
  auto resultTy = op->getResultTypes().front().cast<RankedTensorType>();
  if (requiresDynamicReshape) {
    assert(resultShape && "expect to have reified the result shape");
    intermResult = rewriter.create<DynamicReshapeOp>(
        loc, resultTy, intermResult, *resultShape);
  }

  // Otherwise, restore the expected shape by shape expansion, if required.
  int64_t resultRank = resultTy.getRank();
  int64_t intermResultRank =
      intermResult.getType().cast<RankedTensorType>().getRank();
  bool requiresExpand =
      !requiresDynamicReshape && resultRank != intermResultRank;
  if (requiresExpand) {
    assert(intermResultRank <= 1 &&
           "expect intermediate result to be of rank 0 or 1 before expansion");
    SmallVector<ReassociationIndices, 1> reassociation;
    bool isScalarExpansion = intermResultRank == 0;
    if (!isScalarExpansion)
      reassociation = {llvm::to_vector(llvm::seq<int64_t>(0, resultRank))};
    intermResult = rewriter.create<tensor::ExpandShapeOp>(
        loc, resultTy, intermResult, reassociation);
  }

  rewriter.replaceOp(op, intermResult);
  return success();
}

struct GroupReductionDimensionsPattern : public OpRewritePattern<ReduceOp> {
  GroupReductionDimensionsPattern(MLIRContext* ctx,
                                  bool preferColumnsReductions)
      : OpRewritePattern<ReduceOp>(ctx, /*benefit=*/1),
        preferColumnsReductions(preferColumnsReductions) {}

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    // Only apply to reductions of a single operand.
    if (op.operands().size() != 1 || op.init_values().size() != 1)
      return failure();
    Value arg = op.operands().front();
    auto argTy = arg.getType().cast<RankedTensorType>();

    // Sort the reduction dimensions, which is not an invariant of the op.
    SmallVector<int64_t> orderedReductionDims =
        llvm::to_vector<4>(llvm::map_range(op.dimensions(), [](auto d) {
          return static_cast<int64_t>(d.getLimitedValue());
        }));
    std::sort(orderedReductionDims.begin(), orderedReductionDims.end());

    // If all reduction dimensions are known to be of extent 1, then we can
    // express the reduction through an equivalent collapsing op.
    if (succeeded(tryLowerToCollapseShape(op, argTy, arg, orderedReductionDims,
                                          rewriter))) {
      return success();
    }

    // Otherwise, try lowering the reduction to an equivalent 1D or 2D
    // reduction, inserting transposes if needed.
    if (succeeded(
            tryLowerTo1DOr2DReduction(op, argTy, arg, orderedReductionDims,
                                      preferColumnsReductions, rewriter))) {
      return success();
    }

    return failure();
  }

  bool preferColumnsReductions;
};

struct GroupReductionDimensionsPass
    : public GroupReductionDimensionsPassBase<GroupReductionDimensionsPass> {
  explicit GroupReductionDimensionsPass(bool preferColumnsReductions)
      : GroupReductionDimensionsPassBase<
            GroupReductionDimensionsPass>::GroupReductionDimensionsPassBase() {
    prefer_columns_reductions_ = preferColumnsReductions;
  }

  void runOnOperation() override {
    MLIRContext* ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateGroupReductionDimensionsPatterns(ctx, &patterns,
                                             prefer_columns_reductions_);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void populateGroupReductionDimensionsPatterns(MLIRContext* context,
                                              RewritePatternSet* patterns,
                                              bool preferColumnsReductions) {
  patterns->add<GroupReductionDimensionsPattern>(context,
                                                 preferColumnsReductions);
}

std::unique_ptr<OperationPass<func::FuncOp>> createGroupReductionDimensionsPass(
    bool preferColumnsReductions) {
  return std::make_unique<GroupReductionDimensionsPass>(
      preferColumnsReductions);
}

}  // namespace mhlo
}  // namespace mlir