/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

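// Replaces a `shape.shape_of` op on the result of an op implementing
// `InferShapedTypeOpInterface` with the shape computation reified from that
// op. A rough IR sketch (illustrative only; the exact reified ops depend on
// the defining op's interface implementation):
//
//   %0 = "mhlo.add"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//   %1 = shape.shape_of %0 : tensor<?xf32> -> tensor<1xindex>
//
// may become
//
//   %0 = "mhlo.add"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//   %1 = shape.shape_of %a : tensor<?xf32> -> tensor<1xindex>
//
// since the result shape of an elementwise op can be reified from one of its
// operands.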
struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
  explicit ShapeReificationPattern(MLIRContext *ctx)
      : OpRewritePattern<shape::ShapeOfOp>(ctx) {
    // Recursively reify until we hit an op that doesn't support it.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto origin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
    if (!origin) return failure();
    SmallVector<Value, 1> reifications;
    if (failed(origin.reifyReturnTypeShapes(rewriter, origin->getOperands(),
                                            reifications))) {
      return failure();
    }
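    // `reifyReturnTypeShapes` produces one shape value per result of the
    // defining op; select the one corresponding to the result whose shape
    // this `shape_of` op queries.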
    Value shape = reifications[op.getArg().cast<OpResult>().getResultNumber()];

    // Insert cast, if needed.
    if (shape.getType() != op.getType()) {
      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
    }

    rewriter.replaceOp(op, shape);
    return success();
  }
};

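// Moves `shape.shape_of` ops that consume results of a `shape.assuming` op
// into the assuming region, yielding the computed shapes as additional
// results. This exposes the shape computations to further reification and
// allows dead value results to be pruned. A rough IR sketch (illustrative
// only, types and details elided):
//
//   %r = shape.assuming %w -> (tensor<?xf32>) {
//     ...
//     shape.assuming_yield %v : tensor<?xf32>
//   }
//   %s = shape.shape_of %r : tensor<?xf32> -> tensor<1xindex>
//
// may become
//
//   %r:2 = shape.assuming %w -> (tensor<?xf32>, tensor<1xindex>) {
//     ...
//     %s = shape.shape_of %v : tensor<?xf32> -> tensor<1xindex>
//     shape.assuming_yield %v, %s : tensor<?xf32>, tensor<1xindex>
//   }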
struct ShapeReificationThroughAssumingOpsPattern
    : public OpRewritePattern<shape::AssumingOp> {
  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::AssumingOp aop,
                                PatternRewriter &rewriter) const override {
    // Determine, for each result, which `shape_of` ops query its shape and
    // whether its value has any other uses.
    size_t numResults = aop->getNumResults();
    SmallVector<SmallVector<shape::ShapeOfOp>> shapeUsersPerResult;
    shapeUsersPerResult.reserve(numResults);
    SmallVector<bool> hasNonShapeUsersPerResult;
    hasNonShapeUsersPerResult.reserve(numResults);
    for (Value result : aop.getResults()) {
      auto &shapeUsers = shapeUsersPerResult.emplace_back();
      auto &hasNonShapeUsers = hasNonShapeUsersPerResult.emplace_back(false);
      for (Operation *user : result.getUsers()) {
        if (auto sop = llvm::dyn_cast<shape::ShapeOfOp>(user)) {
          shapeUsers.push_back(sop);
        } else {
          hasNonShapeUsers = true;
        }
      }
    }

    // Fail if there is nothing to make progress on: no result's shape is
    // queried and no result is dead, so rebuilding the op would change
    // nothing.
    if (llvm::all_of(shapeUsersPerResult, [](auto it) { return it.empty(); }) &&
        llvm::all_of(hasNonShapeUsersPerResult, [](auto it) { return it; })) {
      return failure();
    }

    // Create a new assuming op.
    auto newAop = rewriter.create<shape::AssumingOp>(
        aop.getLoc(), aop.getWitness(), [&](OpBuilder &b, Location loc) {
          // From the old assuming op, move all ops over to this new one, except
          // the yield terminator.
          Block *aopBody = aop.getBody();
          auto yop =
              llvm::cast<shape::AssumingYieldOp>(aopBody->getTerminator());
          Block *newAopBody = b.getInsertionBlock();
          auto &dstOps = newAopBody->getOperations();
          auto &srcOps = aopBody->getOperations();
          dstOps.splice(dstOps.begin(), srcOps, srcOps.begin(),
                        yop->getIterator());

          // Collect all the values that have non-shape uses to yield them from
          // the body. Also, create the needed `shape_of` ops at the end of the
          // body and yield these results.
          b.setInsertionPointToEnd(newAopBody);
          SmallVector<Value> results;
          SmallVector<Value> shapeResults;
          for (const auto &it : llvm::enumerate(yop.getOperands())) {
            if (hasNonShapeUsersPerResult[it.index()]) {
              results.push_back(it.value());
            }
            if (!shapeUsersPerResult[it.index()].empty()) {
              shapeResults.push_back(
                  b.create<shape::ShapeOfOp>(loc, it.value()));
            }
          }
          results.append(shapeResults);
          return results;
        });

    // Find the replacement values for the old assuming op. The new op yields
    // the values that still have non-shape uses first, in their original
    // order, followed by the reified shapes; `i` walks both groups.
    size_t i = 0;
    auto newAopResults = newAop.getResults();
    auto replacement = llvm::to_vector<8>(llvm::map_range(
        hasNonShapeUsersPerResult, [&](bool hasNonShapeUses) -> Value {
          return hasNonShapeUses ? newAopResults[i++] : nullptr;
        }));

    // Replace all the shape uses with the shape values from the new assuming
    // region.
    for (const auto &shapeUsers : shapeUsersPerResult) {
      if (shapeUsers.empty()) continue;
      for (shape::ShapeOfOp sop : shapeUsers) {
        rewriter.replaceOp(sop, newAopResults[i]);
      }
      i++;
    }
    assert(i == newAopResults.size() &&
           "expect to use all results of the new assuming op");

    // Finally, replace the old assuming op.
    rewriter.replaceOp(aop, replacement);
    return success();
  }
};

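// Pass that greedily applies the two patterns above to a function until a
// fixed point is reached.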
struct ShapeReificationPass
    : public ShapeReificationPassBase<ShapeReificationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect>();
  }

  void runOnOperation() override {
    // Collect patterns.
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateShapeReificationPatterns(ctx, &patterns);

    // Apply the patterns from the bottom up; this ensures that at most one
    // iteration is needed.
    GreedyRewriteConfig cfg;
    cfg.useTopDownTraversal = false;
    func::FuncOp f = getOperation();
    if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns), cfg))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void populateShapeReificationPatterns(MLIRContext *ctx,
                                      RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<
      ShapeReificationPattern,
      ShapeReificationThroughAssumingOpsPattern>(ctx);
  // clang-format on
}

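// Hypothetical registration in a pass pipeline (illustrative only; `pm` is
// an OpPassManager that is not defined in this file):
//   pm.addNestedPass<func::FuncOp>(mhlo::createShapeReificationPass());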
std::unique_ptr<OperationPass<func::FuncOp>> createShapeReificationPass() {
  return std::make_unique<ShapeReificationPass>();
}

}  // namespace mhlo
}  // namespace mlir