/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <algorithm>
#include <utility>

#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

// Maximum rank that is allowed. Ops producing tensors of higher rank are
// rewritten to stay within this limit. This can be generalized as a pass
// parameter depending on the use-cases.
constexpr int64_t kMaxRank = 5;

// Rewrites a Reshape -> Transpose -> Reshape sequence of ops originating from
// the lowering of ops like SpaceToBatchND.
//
// The input to the first Reshape is a tensor in NHWC format, in 4D or 5D.
//
// The first reshape splits the spatial dimensions, generating two dimensions
// for each spatial dimension. Then, the transpose moves the second part of
// each split dimension to the beginning. The final reshape op combines the
// first dimension with the moved dimensions.
//
// reshape(NxHxWxC) -> (Nx(H/B1)xB1x(W/B2)xB2xC)
// transpose(Nx(H/B1)xB1x(W/B2)xB2xC) -> (B1xB2xNx(H/B1)x(W/B2)xC)
// reshape(B1xB2xNx(H/B1)x(W/B2)xC) -> ((B1*B2*N)x(H/B1)x(W/B2)xC)
//
// The rank of the intermediate tensors can be reduced by doing one transpose
// for each of the spatial dims in sequence, as illustrated below.
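//
// As a concrete illustration (with example sizes N=1, H=4, W=6, C=3 and block
// sizes B1=2, B2=3, chosen purely for exposition), the matched pattern is:
//
//   reshape(1x4x6x3)       -> (1x2x2x2x3x3)
//   transpose(1x2x2x2x3x3) -> (2x3x1x2x2x3)
//   reshape(2x3x1x2x2x3)   -> (6x2x2x3)
//
// and the per-spatial-dim rewrite emitted by this pattern keeps every
// intermediate tensor at rank 5 instead of rank 6.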
struct RewriteReshapeTransposeReshape : public OpRewritePattern<TransposeOp> {
  using OpRewritePattern<TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value result = op.getResult();
    TensorType resultTy = result.getType().cast<TensorType>();
    Value operand = op.operand();
    TensorType operandTy = operand.getType().cast<TensorType>();
    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "transpose op has non-static types");

    if (resultTy.getRank() <= kMaxRank)
      return rewriter.notifyMatchFailure(op,
                                         "already has right dimensionality");

    if (!operand.hasOneUse() || !result.hasOneUse())
      return rewriter.notifyMatchFailure(
          op, "transpose op operand or result has multiple uses");

    auto defOp = operand.getDefiningOp<ReshapeOp>();
    if (!defOp)
      return rewriter.notifyMatchFailure(
          op, "defining op for operand is not reshape");

    auto userOp = llvm::dyn_cast<ReshapeOp>(result.use_begin().getUser());
    if (!userOp)
      return rewriter.notifyMatchFailure(op,
                                         "user of the result is not reshape");

    Value input = defOp.operand();
    auto inputTy = input.getType().cast<TensorType>();
    auto outputTy = userOp.getType();
    if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "reshape op input or output type is not static");

    int64_t inputRank = inputTy.getRank();
    int64_t outputRank = outputTy.getRank();
    if (inputRank != outputRank)
      return rewriter.notifyMatchFailure(
          op, "reshape op input and output ranks are different");

    int64_t spatialDims = inputRank - 2;
    if (spatialDims < 0 || operandTy.getRank() != 2 + 2 * spatialDims)
      return rewriter.notifyMatchFailure(
          op, "transpose op operand isn't expanding spatial dims");

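    // Compute the permutation expected from the SpaceToBatchND-style lowering
    // and collect the block sizes. For example, with two spatial dimensions
    // the operand layout is (N x H/B1 x B1 x W/B2 x B2 x C) and the expected
    // permutation is [2, 4, 0, 1, 3, 5], i.e. the block dims move to the
    // front.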
    SmallVector<int64_t, 4> blockSizes;
    SmallVector<int64_t, 6> expectedPerm(operandTy.getRank());
    expectedPerm[spatialDims] = 0;

    if (inputTy.getDimSize(0) != operandTy.getDimSize(0))
      return rewriter.notifyMatchFailure(
          op, "reshape op isn't expanding only spatial dims");
    for (int64_t dim = 0; dim < spatialDims; dim++) {
      int64_t blockSize = operandTy.getDimSize(1 + dim * 2 + 1);
      if (inputTy.getDimSize(1 + dim) !=
          operandTy.getDimSize(1 + dim * 2) * blockSize)
        return rewriter.notifyMatchFailure(
            op, "reshape op isn't only expanding spatial dims");
      blockSizes.push_back(blockSize);

      expectedPerm[dim] = 1 + 2 * dim + 1;
      expectedPerm[1 + spatialDims + dim] = 1 + 2 * dim;
    }
    expectedPerm[1 + 2 * spatialDims] = 1 + 2 * spatialDims;

    SmallVector<int64_t, 6> perm(op.permutation().getValues<int64_t>());
    if (perm != expectedPerm)
      return rewriter.notifyMatchFailure(
          op, "transpose op isn't only moving spatial dims");

    SmallVector<int64_t, 4> outShape;
    outShape.push_back(resultTy.getDimSize(0));
    for (int64_t dim = 0; dim < spatialDims; dim++) {
      outShape[0] *= resultTy.getDimSize(1 + dim);
      outShape.push_back(resultTy.getDimSize(1 + spatialDims + dim));
    }
    outShape.push_back(resultTy.getDimSize(1 + spatialDims * 2));
    if (llvm::to_vector<4>(outputTy.getShape()) != outShape)
      return rewriter.notifyMatchFailure(
          op, "reshape op isn't only combining block dims");

    // Now that the input patterns are verified, introduce a sequence of
    // reshape->transpose->reshape ops for each of the spatial dimensions. We
    // need to start with the last spatial dimension so that the block
    // dimensions are folded into the batch dimension in the original order,
    // as illustrated below.
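    //
    // Continuing the illustrative sizes from the pattern documentation above
    // (N=1, H=4, W=6, C=3, B1=2, B2=3), the emitted ops are:
    //
    //   dim = 1 (W):  reshape(1x4x6x3)     -> (1x4x2x3x3)
    //                 transpose(1x4x2x3x3) -> (3x1x4x2x3)
    //                 reshape(3x1x4x2x3)   -> (3x4x2x3)
    //   dim = 0 (H):  reshape(3x4x2x3)     -> (3x2x2x2x3)
    //                 transpose(3x2x2x2x3) -> (2x3x2x2x3)
    //                 reshape(2x3x2x2x3)   -> (6x2x2x3)
    //
    // which produces the same ((B1*B2*N) x (H/B1) x (W/B2) x C) result while
    // keeping every intermediate tensor at rank 5.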
    for (int dim = spatialDims - 1; dim >= 0; dim--) {
      // 1) Reshape to split the particular spatial dimension.
      auto inputTy = input.getType().cast<TensorType>();
      auto intermediateShape = llvm::to_vector<4>(inputTy.getShape());
      int64_t dimIdx = 1 + dim;
      intermediateShape[dimIdx] /= blockSizes[dim];
      int64_t blockIdx = dimIdx + 1;
      intermediateShape.insert(intermediateShape.begin() + blockIdx,
                               blockSizes[dim]);
      auto reshapedTy =
          RankedTensorType::get(intermediateShape, inputTy.getElementType());

      auto reshape = rewriter.create<ReshapeOp>(op.getLoc(), reshapedTy, input);

      // 2) Transpose to move the block part of the split dimension to the
      // beginning.
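      //
      // For example, for a rank-5 intermediate with blockIdx == 3, the
      // permutation built below is [3, 0, 1, 2, 4].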
      SmallVector<int64_t, 8> perm;
      perm.push_back(blockIdx);
      perm.push_back(0);
      for (int i = 1, e = reshapedTy.getRank(); i != e; i++) {
        if (i != perm[0]) perm.push_back(i);
      }

      auto transpose = rewriter.create<TransposeOp>(
          op.getLoc(), reshape, rewriter.getI64TensorAttr(perm));

      // 3) Reshape to combine the block dimension with the batch dimension.
      intermediateShape = llvm::to_vector<4>(transpose.getType().getShape());
      intermediateShape[0] *= intermediateShape[1];
      intermediateShape.erase(intermediateShape.begin() + 1);
      reshapedTy =
          RankedTensorType::get(intermediateShape, inputTy.getElementType());

      input = rewriter.create<ReshapeOp>(op.getLoc(), reshapedTy, transpose);
    }

    rewriter.replaceOp(userOp, input);
    return success();
  }
};

struct RestrictMaxRankPass
    : public RestrictMaxRankPassBase<RestrictMaxRankPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Collect patterns.
    RewritePatternSet patterns(ctx);
    patterns.add<RewriteReshapeTransposeReshape>(ctx);

    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            GreedyRewriteConfig()))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createRestrictMaxRankPass() {
  return std::make_unique<RestrictMaxRankPass>();
}
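
// Example of adding this pass to a pipeline (a minimal sketch; the
// surrounding PassManager setup is assumed and not part of this file):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::mhlo::createRestrictMaxRankPass());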

}  // namespace mhlo
}  // namespace mlir