/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <algorithm>
#include <utility>

#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

// Maximum rank that is allowed. Tensors with a higher rank should be
// rewritten to stay within this rank. This can be generalized into a pass
// parameter depending on the use case.
constexpr int64_t kMaxRank = 5;
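
// For example, lowering SpaceToBatchND on a 4-D NHWC input with two spatial
// dimensions produces a rank-6 intermediate tensor (batch + two split dims
// per spatial dim + channel), which exceeds kMaxRank and is handled by the
// pattern below.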

// Rewrites a Reshape -> Transpose -> Reshape sequence of ops originating from
// the lowering of ops like SpaceToBatchND.
//
// Input to the first Reshape is a 4-D or 5-D tensor in NHWC format.
//
// The first reshape splits each spatial dimension into two dimensions. Then,
// the transpose moves the second part of each split dimension to the
// beginning. The final reshape combines the first dimension with the moved
// dimensions.
//
// reshape(NxHxWxC) -> (Nx(H/B1)xB1x(W/B2)xB2xC)
// transpose(Nx(H/B1)xB1x(W/B2)xB2xC) -> (B1xB2xNx(H/B1)x(W/B2)xC)
// reshape(B1xB2xNx(H/B1)x(W/B2)xC) -> ((B1*B2*N)x(H/B1)x(W/B2)xC)
//
// The rank of the intermediate tensors can be reduced by handling one spatial
// dimension at a time, each with its own reshape -> transpose -> reshape.
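//
// As a concrete illustration (example values only): for a 1x4x6x8 input with
// block sizes B1 = 2 and B2 = 3, the matched sequence is
//   reshape:   1x4x6x8     -> 1x2x2x2x3x8
//   transpose: 1x2x2x2x3x8 -> 2x3x1x2x2x8   (permutation [2, 4, 0, 1, 3, 5])
//   reshape:   2x3x1x2x2x8 -> 6x2x2x8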
struct RewriteReshapeTransposeReshape : public OpRewritePattern<TransposeOp> {
  using OpRewritePattern<TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value result = op.getResult();
    TensorType resultTy = result.getType().cast<TensorType>();
    Value operand = op.operand();
    TensorType operandTy = operand.getType().cast<TensorType>();
    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "transpose op has non-static types");

    if (resultTy.getRank() <= kMaxRank)
      return rewriter.notifyMatchFailure(op,
                                         "already has right dimensionality");

    if (!operand.hasOneUse() || !result.hasOneUse())
      return rewriter.notifyMatchFailure(
          op, "transpose op operand or result has multiple uses");

    auto defOp = operand.getDefiningOp<ReshapeOp>();
    if (!defOp)
      return rewriter.notifyMatchFailure(
          op, "defining op for operand is not reshape");

    auto userOp = llvm::dyn_cast<ReshapeOp>(result.use_begin().getUser());
    if (!userOp)
      return rewriter.notifyMatchFailure(op,
                                         "user of the result is not reshape");

    Value input = defOp.operand();
    auto inputTy = input.getType().cast<TensorType>();
    auto outputTy = userOp.getType();
    if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "reshape op input or output type is not static");

    int64_t inputRank = inputTy.getRank();
    int64_t outputRank = outputTy.getRank();
    if (inputRank != outputRank)
      return rewriter.notifyMatchFailure(
          op, "reshape op input and output rank are different");

    int64_t spatialDims = inputRank - 2;
    if (spatialDims < 0 || operandTy.getRank() != 2 + 2 * spatialDims)
      return rewriter.notifyMatchFailure(
          op, "transpose op operand isn't expanding spatial dims");

    SmallVector<int64_t, 4> blockSizes;
    SmallVector<int64_t, 6> expectedPerm(operandTy.getRank());
    expectedPerm[spatialDims] = 0;

    if (inputTy.getDimSize(0) != operandTy.getDimSize(0))
      return rewriter.notifyMatchFailure(
          op, "reshape op isn't expanding only spatial dims");
    for (int64_t dim = 0; dim < spatialDims; dim++) {
      int64_t blockSize = operandTy.getDimSize(1 + dim * 2 + 1);
      if (inputTy.getDimSize(1 + dim) !=
          operandTy.getDimSize(1 + dim * 2) * blockSize)
        return rewriter.notifyMatchFailure(
            op, "reshape op isn't only expanding spatial dims");
      blockSizes.push_back(blockSize);

      expectedPerm[dim] = 1 + 2 * dim + 1;
      expectedPerm[1 + spatialDims + dim] = 1 + 2 * dim;
    }
    expectedPerm[1 + 2 * spatialDims] = 1 + 2 * spatialDims;
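    // For instance, with two spatial dims the expected permutation is
    // [2, 4, 0, 1, 3, 5]: the block dims move to the front, followed by the
    // batch dim, the split spatial dims, and the channel dim.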

    SmallVector<int64_t, 6> perm(op.permutation().getValues<int64_t>());
    if (perm != expectedPerm)
      return rewriter.notifyMatchFailure(
          op, "transpose op isn't only moving spatial dims");

    SmallVector<int64_t, 4> outShape;
    outShape.push_back(resultTy.getDimSize(0));
    for (int64_t dim = 0; dim < spatialDims; dim++) {
      outShape[0] *= resultTy.getDimSize(1 + dim);
      outShape.push_back(resultTy.getDimSize(1 + spatialDims + dim));
    }
    outShape.push_back(resultTy.getDimSize(1 + spatialDims * 2));
    if (llvm::to_vector<4>(outputTy.getShape()) != outShape)
      return rewriter.notifyMatchFailure(
          op, "reshape op isn't only combining block dims");

    // Now that the input pattern is verified, introduce a separate
    // reshape->transpose->reshape sequence for each of the spatial dimensions.
    // We need to start with the last spatial dimension so that the block
    // dimensions end up in the original order (B1 before B2) within the
    // combined batch dimension.
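    // Continuing the 1x4x6x8 example (B1 = 2, B2 = 3), the two iterations
    // produce:
    //   W: 1x4x6x8 -> 1x4x2x3x8 -> 3x1x4x2x8 -> 3x4x2x8
    //   H: 3x4x2x8 -> 3x2x2x2x8 -> 2x3x2x2x8 -> 6x2x2x8
    // No intermediate tensor exceeds rank 5.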
    for (int dim = spatialDims - 1; dim >= 0; dim--) {
      // 1) Reshape to split the particular spatial dimension.
      auto inputTy = input.getType().cast<TensorType>();
      auto intermediateShape = llvm::to_vector<4>(inputTy.getShape());
      int64_t dimIdx = 1 + dim;
      intermediateShape[dimIdx] /= blockSizes[dim];
      int64_t blockIdx = dimIdx + 1;
      intermediateShape.insert(intermediateShape.begin() + blockIdx,
                               blockSizes[dim]);
      auto reshapedTy =
          RankedTensorType::get(intermediateShape, inputTy.getElementType());

      auto reshape = rewriter.create<ReshapeOp>(op.getLoc(), reshapedTy, input);

      // 2) Transpose to move the block part of the split dimension to the
      // beginning.
      SmallVector<int64_t, 8> perm;
      perm.push_back(blockIdx);
      perm.push_back(0);
      for (int i = 1, e = reshapedTy.getRank(); i != e; i++) {
        if (i != perm[0]) perm.push_back(i);
      }
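      // E.g. in the first (W) iteration of the 1x4x6x8 example, blockIdx is 3
      // and perm ends up as [3, 0, 1, 2, 4], moving the freshly split block
      // dimension to the front.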

      auto transpose = rewriter.create<TransposeOp>(
          op.getLoc(), reshape, rewriter.getI64TensorAttr(perm));

      // 3) Reshape to combine the block dimension with the batch dimension.
      intermediateShape = llvm::to_vector<4>(transpose.getType().getShape());
      intermediateShape[0] *= intermediateShape[1];
      intermediateShape.erase(intermediateShape.begin() + 1);
      reshapedTy =
          RankedTensorType::get(intermediateShape, inputTy.getElementType());

      input = rewriter.create<ReshapeOp>(op.getLoc(), reshapedTy, transpose);
    }

    rewriter.replaceOp(userOp, input);
    return success();
  }
};

struct RestrictMaxRankPass
    : public RestrictMaxRankPassBase<RestrictMaxRankPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Collect patterns.
    RewritePatternSet patterns(ctx);
    patterns.add<RewriteReshapeTransposeReshape>(ctx);

    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            GreedyRewriteConfig()))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createRestrictMaxRankPass() {
  return std::make_unique<RestrictMaxRankPass>();
}

}  // namespace mhlo
}  // namespace mlir