1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14
15 ==============================================================================*/
16
17 #include <memory>
18 #include <utility>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/Shape/IR/Shape.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Interfaces/InferTypeOpInterface.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31 namespace mlir {
32 namespace mhlo {
33 namespace {
34
35 struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
ShapeReificationPatternmlir::mhlo::__anone1fdca7d0111::ShapeReificationPattern36 explicit ShapeReificationPattern(MLIRContext *ctx)
37 : OpRewritePattern<shape::ShapeOfOp>(ctx) {
38 // Recursively reify until we hit an op that doesn't support it.
39 setHasBoundedRewriteRecursion();
40 }
41
matchAndRewritemlir::mhlo::__anone1fdca7d0111::ShapeReificationPattern42 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
43 PatternRewriter &rewriter) const override {
44 auto origin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
45 if (!origin) return failure();
46 SmallVector<Value, 1> reifications;
47 if (failed(origin.reifyReturnTypeShapes(rewriter, origin->getOperands(),
48 reifications))) {
49 return failure();
50 }
51 Value shape = reifications[op.getArg().cast<OpResult>().getResultNumber()];
52
53 // Insert cast, if needed.
54 if (shape.getType() != op.getType()) {
55 shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
56 }
57
58 rewriter.replaceOp(op, shape);
59 return success();
60 }
61 };
62
63 struct ShapeReificationThroughAssumingOpsPattern
64 : public OpRewritePattern<shape::AssumingOp> {
65 using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
matchAndRewritemlir::mhlo::__anone1fdca7d0111::ShapeReificationThroughAssumingOpsPattern66 LogicalResult matchAndRewrite(shape::AssumingOp aop,
67 PatternRewriter &rewriter) const override {
68 // Analyze in which results' values and shapes we are interested.
69 size_t numResults = aop->getNumResults();
70 SmallVector<SmallVector<shape::ShapeOfOp>> shapeUsersPerResult;
71 shapeUsersPerResult.reserve(numResults);
72 SmallVector<bool> hasNonShapeUsersPerResult;
73 hasNonShapeUsersPerResult.reserve(numResults);
74 for (Value result : aop.getResults()) {
75 auto &shapeUsers = shapeUsersPerResult.emplace_back();
76 auto &hasNonShapeUsers = hasNonShapeUsersPerResult.emplace_back(false);
77 for (Operation *user : result.getUsers()) {
78 if (auto sop = llvm::dyn_cast<shape::ShapeOfOp>(user)) {
79 shapeUsers.push_back(sop);
80 } else {
81 hasNonShapeUsers = true;
82 }
83 }
84 }
85
86 // Fail, if there is nothing to make progress on.
87 if (llvm::all_of(shapeUsersPerResult, [](auto it) { return it.empty(); }) &&
88 llvm::all_of(hasNonShapeUsersPerResult, [](auto it) { return it; })) {
89 return failure();
90 }
91
92 // Create a new assuming op.
93 auto newAop = rewriter.create<shape::AssumingOp>(
94 aop.getLoc(), aop.getWitness(), [&](OpBuilder &b, Location loc) {
95 // From the old assuming op, move all ops over to this new one, except
96 // the yield terminator.
97 Block *aopBody = aop.getBody();
98 auto yop =
99 llvm::cast<shape::AssumingYieldOp>(aopBody->getTerminator());
100 Block *newAopBody = b.getInsertionBlock();
101 auto &dstOps = newAopBody->getOperations();
102 auto &srcOps = aopBody->getOperations();
103 dstOps.splice(dstOps.begin(), srcOps, srcOps.begin(),
104 yop->getIterator());
105
106 // Collect all the values that have non-shape uses to yield them from
107 // the body. Also, create the needed `shape_of` ops at the end of the
108 // body and yield these results.
109 b.setInsertionPointToEnd(newAopBody);
110 SmallVector<Value> results;
111 SmallVector<Value> shapeResults;
112 for (const auto &it : llvm::enumerate(yop.getOperands())) {
113 if (hasNonShapeUsersPerResult[it.index()]) {
114 results.push_back(it.value());
115 }
116 if (!shapeUsersPerResult[it.index()].empty()) {
117 shapeResults.push_back(
118 b.create<shape::ShapeOfOp>(loc, it.value()));
119 }
120 }
121 results.append(shapeResults);
122 return results;
123 });
124
125 // Find the replacement values for the old assuming op.
126 size_t i = 0;
127 auto newAopResults = newAop.getResults();
128 auto replacement = llvm::to_vector<8>(llvm::map_range(
129 hasNonShapeUsersPerResult, [&](bool hasNonShapeUses) -> Value {
130 return hasNonShapeUses ? newAopResults[i++] : nullptr;
131 }));
132
133 // Replace all the shape uses with the shape values from the new assuming
134 // region.
135 for (const auto &shapeUsers : shapeUsersPerResult) {
136 if (shapeUsers.empty()) continue;
137 for (shape::ShapeOfOp sop : shapeUsers) {
138 rewriter.replaceOp(sop, newAopResults[i]);
139 }
140 i++;
141 }
142 assert(i == newAopResults.size() &&
143 "expect to use all results of the new assuming op");
144
145 // Finally, replace the old assuming op.
146 rewriter.replaceOp(aop, replacement);
147 return success();
148 }
149 };
150
151 struct ShapeReificationPass
152 : public ShapeReificationPassBase<ShapeReificationPass> {
getDependentDialectsmlir::mhlo::__anone1fdca7d0111::ShapeReificationPass153 void getDependentDialects(DialectRegistry ®istry) const override {
154 registry.insert<shape::ShapeDialect>();
155 }
156
runOnOperationmlir::mhlo::__anone1fdca7d0111::ShapeReificationPass157 void runOnOperation() override {
158 // Collect patterns.
159 MLIRContext *ctx = &getContext();
160 RewritePatternSet patterns(ctx);
161 populateShapeReificationPatterns(ctx, &patterns);
162
163 // Apply patterns from the bottom up. This ensures to need no more than one
164 // iteration.
165 GreedyRewriteConfig cfg;
166 cfg.useTopDownTraversal = false;
167 func::FuncOp f = getOperation();
168 if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns), cfg))) {
169 return signalPassFailure();
170 }
171 }
172 };
173
174 } // namespace
175
populateShapeReificationPatterns(MLIRContext * ctx,RewritePatternSet * patterns)176 void populateShapeReificationPatterns(MLIRContext *ctx,
177 RewritePatternSet *patterns) {
178 // clang-format off
179 patterns->add<
180 ShapeReificationPattern,
181 ShapeReificationThroughAssumingOpsPattern>(ctx);
182 // clang-format on
183 }
184
createShapeReificationPass()185 std::unique_ptr<OperationPass<func::FuncOp>> createShapeReificationPass() {
186 return std::make_unique<ShapeReificationPass>();
187 }
188
189 } // namespace mhlo
190 } // namespace mlir
191