/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

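// Lowers a reduction to a tensor.collapse_shape op when all declared
// reduction dimensions are known to be of extent 1 and therefore do not
// contribute to the result. Illustrative example (%arg is a placeholder):
// reducing tensor<1x5x1xf32> over dimensions [0, 2] is equivalent to
//   tensor.collapse_shape %arg [[0, 1, 2]]
//       : tensor<1x5x1xf32> into tensor<5xf32>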
LogicalResult tryLowerToCollapseShape(
    ReduceOp op, RankedTensorType argTy, Value arg,
    SmallVector<int64_t>& orderedReductionDims, PatternRewriter& rewriter) {
  // This only works for trivial reductions where all declared reduction
  // dimensions are of extent 1.
  if (!llvm::all_of(orderedReductionDims,
                    [argTy](int64_t i) { return argTy.getDimSize(i) == 1; })) {
    return failure();
  }

  int64_t argRank = argTy.getRank();
  int64_t numReductionDims = orderedReductionDims.size();

  int64_t j = 0;
  auto isDeclaredAsReductionDim = [&](int64_t i) {
    if (j < numReductionDims && orderedReductionDims[j] == i) {
      j++;
      return true;
    }
    return false;
  };

  // Build reassociation indices.
  SmallVector<ReassociationIndices, 4> reassociation;
  int64_t iBegin = 0;
  int64_t i = 0;
  while (i < argRank && isDeclaredAsReductionDim(i)) i++;
  while (i < argRank) {
    i++;
    while (i < argRank && isDeclaredAsReductionDim(i)) i++;
    reassociation.push_back(llvm::to_vector(llvm::seq(iBegin, i)));
    iBegin = i;
  }
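  // For example (illustrative): with argTy = tensor<4x1x5xf32> and reduction
  // dimension 1, the loop above produces the reassociation [[0, 1], [2]],
  // collapsing the argument to tensor<4x5xf32>.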

  // Lower reduction op to collapse shape op.
  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, arg, reassociation);
  return success();
}

enum class DimensionKind {
  kParallel,
  kReduction,
  kDegenerate,
};

struct DimensionGroup {
  DimensionKind kind;
  int64_t begin;
  int64_t end;
  int64_t size() { return end - begin; }
};

// Groups consecutive dimensions of a reduction argument by their kind, i.e.
// whether they are reduction or parallel dimensions. Dimensions of extent 1
// can be considered of either kind.
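// For example (illustrative): for argTy = tensor<10x1x4x4xf32> with reduction
// dimensions [2, 3], this forms one parallel group over [0, 2) (absorbing the
// degenerate dimension 1) and one reduction group over [2, 4).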
void groupDimensions(RankedTensorType argTy,
                     SmallVector<int64_t> orderedReductionDims,
                     SmallVector<DimensionGroup>& groups) {
  int64_t argRank = argTy.getRank();
  int64_t numReductionDims = orderedReductionDims.size();
  int64_t j = 0;
  for (int64_t i = 0; i < argRank; ++i) {
    // Check if the i-th dimension is one of the declared reduction dimensions.
    bool isDeclaredAsReductionDim = false;
    if (j < numReductionDims && i == orderedReductionDims[j]) {
      isDeclaredAsReductionDim = true;
      j++;
    }

    // Use the declared dimension kind unless the dimension is of extent 1, in
    // which case we can consider it either kind. We exploit this to form
    // maximal dimension groups.
    DimensionKind kind = isDeclaredAsReductionDim ? DimensionKind::kReduction
                                                  : DimensionKind::kParallel;
    if (argTy.getDimSize(i) == 1) kind = DimensionKind::kDegenerate;
    // Start a new dimension group if the dimension kind conflicts with the
    // trailing kind.
    if (groups.empty() || (groups.back().kind != kind &&
                           groups.back().kind != DimensionKind::kDegenerate &&
                           kind != DimensionKind::kDegenerate)) {
      groups.push_back({kind, i, i});
    }

    // Include dimension in trailing group and concretize dimension kind if
    // necessary.
    if (groups.back().kind == DimensionKind::kDegenerate)
      groups.back().kind = kind;
    groups.back().end++;
  }
}

LogicalResult tryLowerTo1DOr2DReduction(
    ReduceOp op, RankedTensorType argTy, Value arg,
    SmallVector<int64_t>& orderedReductionDims, bool preferColumnsReductions,
    PatternRewriter& rewriter) {
  // Group the argument dimensions by their kind.
  SmallVector<DimensionGroup> dimGroups;
  groupDimensions(argTy, orderedReductionDims, dimGroups);

  // Do not (re-)apply if the dimensions are already fully collapsed.
  if (dimGroups.size() <= 2 &&
      llvm::all_of(dimGroups, [](auto g) { return g.size() == 1; })) {
    return failure();
  }
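  // (With at most two singleton groups, the op above is already a 1D or 2D
  // reduction, e.g. a plain row reduction of tensor<10x20xf32> over
  // dimension 1.)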

  // Determine whether or not a dynamic reshape is needed for the final result.
  int64_t numDynParallelDims = 0;
  for (auto group : dimGroups) {
    if (group.kind != DimensionKind::kParallel) continue;
    for (int64_t i = group.begin; i < group.end; i++) {
      if (argTy.isDynamicDim(i)) numDynParallelDims++;
    }
  }
  bool requiresDynamicReshape = numDynParallelDims > 1;
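  // (A tensor.expand_shape cannot reconstruct a reassociation group that
  // contains more than one dynamic dimension, so restoring the result shape
  // then requires a mhlo.dynamic_reshape instead.)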

  // Reify the result shape early so that the pattern can fail without altering
  // the IR.
  Optional<Value> resultShape;
  if (requiresDynamicReshape) {
    llvm::SmallVector<Value, 1> reifiedShapes;
    if (failed(llvm::cast<InferShapedTypeOpInterface>(op.getOperation())
                   .reifyReturnTypeShapes(rewriter, op->getOperands(),
                                          reifiedShapes))) {
      return failure();
    }
    assert(reifiedShapes.size() == 1 && "expect exactly one shape");
    resultShape = reifiedShapes.front();
  }

  // Collapse dimension groups so that all adjacent dimensions of the
  // intermediate result are of a different kind.
  Value intermResult = arg;
  auto loc = op.getLoc();
  bool requiresCollapse =
      llvm::any_of(dimGroups, [&](auto g) { return g.size() > 1; });
  if (requiresCollapse) {
    auto reassociation =
        llvm::to_vector(llvm::map_range(dimGroups, [&](auto g) {
          return llvm::to_vector<2>(llvm::seq<int64_t>(g.begin, g.end));
        }));
    intermResult = rewriter.create<tensor::CollapseShapeOp>(loc, intermResult,
                                                            reassociation);
  }
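  // For example (illustrative): for groups [parallel [0, 2), reduction [2, 4)]
  // over tensor<10x1x4x4xf32>, the reassociation [[0, 1], [2, 3]] collapses
  // the argument to tensor<10x16xf32>.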

  // If required, transpose the intermediate result so that the dimension
  // kinds form two partitions, which can be collapsed to a 2D intermediate
  // result.
  bool requiresTranspose = dimGroups.size() > 2;
  if (requiresTranspose) {
    // Materialize transpose.
    DimensionKind leadingDimKind = preferColumnsReductions
                                       ? DimensionKind::kReduction
                                       : DimensionKind::kParallel;
    DimensionKind trailingDimKind = preferColumnsReductions
                                        ? DimensionKind::kParallel
                                        : DimensionKind::kReduction;
    SmallVector<int64_t> perm;
    for (int64_t i = 0; i < static_cast<int64_t>(dimGroups.size()); i++) {
      if (dimGroups[i].kind == leadingDimKind) perm.push_back(i);
    }
    int64_t numLeadingDims = perm.size();
    for (int64_t i = 0; i < static_cast<int64_t>(dimGroups.size()); i++) {
      if (dimGroups[i].kind == trailingDimKind) perm.push_back(i);
    }
    auto permAttr = rewriter.getI64TensorAttr(perm);
    intermResult = rewriter.create<TransposeOp>(loc, intermResult, permAttr)
                       ->getResults()
                       .front();
    // Collapse intermediate result to rank 2.
    SmallVector<ReassociationIndices, 2> reassociation = {
        llvm::to_vector<2>(llvm::seq<int64_t>(0, numLeadingDims)),
        llvm::to_vector<2>(llvm::seq<int64_t>(numLeadingDims, perm.size()))};
    intermResult = rewriter.create<tensor::CollapseShapeOp>(loc, intermResult,
                                                            reassociation);
  }
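  // For example (illustrative): for groups [parallel, reduction, parallel]
  // with preferColumnsReductions = true, the permutation [1, 0, 2] moves the
  // reduction group to the front, and the reassociation [[0], [1, 2]] then
  // collapses the transposed operand to a column-reducible 2D shape.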

  // Materialize inner 1D or 2D reduction.
  bool leadingReduction =
      requiresTranspose ? preferColumnsReductions
                        : dimGroups.front().kind == DimensionKind::kReduction;
  int64_t reductionDim = leadingReduction ? 0 : 1;
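  // (Reducing dimension 0 of the 2D operand is a column reduction; reducing
  // dimension 1 is a row reduction.)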
  auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim});
  Value initVal = op.init_values().front();
  auto reductionOp =
      rewriter.create<ReduceOp>(loc, intermResult, initVal, reductionDimAttr);
  rewriter.inlineRegionBefore(op.body(), reductionOp.body(),
                              reductionOp.body().begin());
  intermResult = reductionOp->getResults().front();

  // Restore the expected shape by dynamic reshape, if required.
  auto resultTy = op->getResultTypes().front().cast<RankedTensorType>();
  if (requiresDynamicReshape) {
    assert(resultShape && "expect to have reified the result shape");
    intermResult = rewriter.create<DynamicReshapeOp>(
        loc, resultTy, intermResult, *resultShape);
  }

  // Otherwise, restore the expected shape by shape expansion, if required.
  int64_t resultRank = resultTy.getRank();
  int64_t intermResultRank =
      intermResult.getType().cast<RankedTensorType>().getRank();
  bool requiresExpand =
      !requiresDynamicReshape && resultRank != intermResultRank;
  if (requiresExpand) {
    assert(intermResultRank <= 1 &&
           "expect intermediate result to be of rank 0 or 1 before expansion");
    SmallVector<ReassociationIndices, 1> reassociation;
    bool isScalarExpansion = intermResultRank == 0;
    if (!isScalarExpansion)
      reassociation = {llvm::to_vector(llvm::seq<int64_t>(0, resultRank))};
    intermResult = rewriter.create<tensor::ExpandShapeOp>(
        loc, resultTy, intermResult, reassociation);
  }
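  // For example (illustrative): a rank-1 intermediate result tensor<50xf32>
  // expands to tensor<10x5xf32> via the reassociation [[0, 1]]; a rank-0
  // result expands with an empty reassociation.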

  rewriter.replaceOp(op, intermResult);
  return success();
}

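// Rewrites a multi-dimensional mhlo.reduce over a single argument into an
// equivalent 1D or 2D reduction of a collapsed (and, if needed, transposed)
// operand. Schematic example (illustrative, with
// prefer_columns_reductions = false): a reduction of tensor<10x1x4x4xf32>
// across dimensions [2, 3] becomes a row reduction of the collapsed operand
// tensor<10x16xf32> across dimension 1, followed by a shape expansion back to
// the expected result type tensor<10x1xf32>.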
struct GroupReductionDimensionsPattern : public OpRewritePattern<ReduceOp> {
  GroupReductionDimensionsPattern(MLIRContext* ctx,
                                  bool preferColumnsReductions)
      : OpRewritePattern<ReduceOp>(ctx, /*benefit=*/1),
        preferColumnsReductions(preferColumnsReductions) {}

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    // Only apply to reductions of a single argument.
    if (op.operands().size() != 1 || op.init_values().size() != 1)
      return failure();
    Value arg = op.operands().front();
    auto argTy = arg.getType().cast<RankedTensorType>();

    // Sort reduction dimensions, which is not an invariant of the op.
    SmallVector<int64_t> orderedReductionDims =
        llvm::to_vector<4>(llvm::map_range(op.dimensions(), [](auto d) {
          return static_cast<int64_t>(d.getLimitedValue());
        }));
    std::sort(orderedReductionDims.begin(), orderedReductionDims.end());

    // If all reduction dimensions are known to be of extent 1 then we can
    // express the reduction through an equivalent collapsing op.
    if (succeeded(tryLowerToCollapseShape(op, argTy, arg, orderedReductionDims,
                                          rewriter))) {
      return success();
    }

    // Otherwise, try lowering the reduction to an equivalent 1D or 2D
    // reduction, and insert transposes if needed.
    if (succeeded(
            tryLowerTo1DOr2DReduction(op, argTy, arg, orderedReductionDims,
                                      preferColumnsReductions, rewriter))) {
      return success();
    }

    return failure();
  }

  bool preferColumnsReductions;
};

struct GroupReductionDimensionsPass
    : public GroupReductionDimensionsPassBase<GroupReductionDimensionsPass> {
  explicit GroupReductionDimensionsPass(bool preferColumnsReductions)
      : GroupReductionDimensionsPassBase<
            GroupReductionDimensionsPass>::GroupReductionDimensionsPassBase() {
    prefer_columns_reductions_ = preferColumnsReductions;
  }

  void runOnOperation() override {
    MLIRContext* ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateGroupReductionDimensionsPatterns(ctx, &patterns,
                                             prefer_columns_reductions_);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void populateGroupReductionDimensionsPatterns(MLIRContext* context,
                                              RewritePatternSet* patterns,
                                              bool preferColumnsReductions) {
  patterns->add<GroupReductionDimensionsPattern>(context,
                                                 preferColumnsReductions);
}

std::unique_ptr<OperationPass<func::FuncOp>> createGroupReductionDimensionsPass(
    bool preferColumnsReductions) {
  return std::make_unique<GroupReductionDimensionsPass>(
      preferColumnsReductions);
}

}  // namespace mhlo
}  // namespace mlir