/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <utility>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
/// `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}

namespace mhlo {
namespace {

/// Identify clusters of operations that can be rank-specialized together. The
/// required traits for clustered operations are:
///   - Element-wise: All operations in the group must be element-wise. This
///     allows reshaping operands before applying the operations as well as
///     reshaping the result to the desired shape afterwards. This way, we can,
///     e.g., apply unary ops to a completely flattened operand and restore the
///     original shape afterwards.
///   - Broadcasting semantics: All operations must implement broadcasting
///     semantics. Most importantly, this allows extending operand shapes such
///     that they match in rank. Operations that require all their operands to
///     be of the same shape also fulfill this requirement.
///   - Shape reification: All operations must implement
///     `InferShapedTypeOpInterface`. This is later needed to compute and to
///     restore the desired result shape.
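///
/// Schematically (illustrative, not verbatim IR), clustering rewrites a chain
/// of rank-specializable ops on unranked tensors such as
///   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
///   %1 = mhlo.tanh %0 : tensor<*xf32>
/// into a single `chlo.rank_specialization_cluster` op that holds both ops in
/// its body region and returns their result through a
/// `chlo.rank_specialization_cluster_yield` terminator.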

bool isClusterable(Operation *op) {
  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
  if (op->getNumOperands() == 0) return false;
  return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
          op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
         op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
}

struct RankSpecializationClusterPattern : public RewritePattern {
  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only apply to operations that have not been clustered yet.
    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
      return failure();
    }

    // Only cluster when rank specialization is needed.
    if (!isClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
          return ty.isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Collect all collectively rank specializable ops.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallSet<Value, 16> operandSet;
    llvm::SmallSet<Value, 16> resultSet;

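    // Grow the cluster: first walk forward to the last consecutive
    // clusterable op (the cluster root), then walk backwards from it,
    // collecting ops and maintaining the set of values that enter the cluster
    // (operands) and the set of values that escape it (results).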
    Operation *rootOp = op;
    while (rootOp->getNextNode() != nullptr &&
           isClusterable(rootOp->getNextNode()))
      rootOp = rootOp->getNextNode();

    Operation *it = rootOp;
    while (it != nullptr && isClusterable(it)) {
      // Find results that escape the cluster.
      for (OpOperand &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          resultSet.insert(use.get());
      }

      // Update cluster operands.
      for (OpResult v : it->getResults()) operandSet.erase(Value(v));
      for (OpOperand &v : it->getOpOperands()) operandSet.insert(v.get());

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Create `RankSpecializationClusterOp`.
    auto operands = llvm::to_vector<16>(operandSet);
    auto results = llvm::to_vector<16>(resultSet);
    auto resultTypes = llvm::to_vector<16>(
        llvm::map_range(resultSet, [](Value v) { return v.getType(); }));
    Location loc = op->getLoc();
    auto clusterOp = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, resultTypes, operands);

    // Create body block.
    auto operandTypes = llvm::to_vector<16>(
        llvm::map_range(operandSet, [](Value v) { return v.getType(); }));
    Block *block =
        rewriter.createBlock(&clusterOp.body(), {}, operandTypes,
                             SmallVector<Location>(operandTypes.size(), loc));

    // Copy operations into the body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::zip(operands, block->getArguments()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    rewriter.setInsertionPointToStart(block);
    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);

    // Create `RankSpecializationClusterYieldOp`.
    auto mappedResults = llvm::to_vector<16>(
        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc, mappedResults);

    // Replace original ops with the new results.
    for (auto it : llvm::zip(results, clusterOp.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};

struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto precedingOp =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!precedingOp) return failure();
    Block *body = op.getBody();
    Block *precedingBody = precedingOp.getBody();
    auto yieldOp = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto precedingYieldOp =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            precedingOp.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> newOperands;
    for (Value v : precedingOp.operands()) newOperands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != precedingOp &&
          !llvm::is_contained(precedingOp.operands(), v)) {
        newOperands.push_back(v);
      }
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second cluster.
    SmallVector<Value, 8> newUnmappedResults;
    for (auto it :
         llvm::zip(precedingOp.results(), precedingYieldOp.results())) {
      Value result, innerResult;
      std::tie(result, innerResult) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        newUnmappedResults.push_back(innerResult);
      }
    }
    for (Value v : yieldOp.results()) newUnmappedResults.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(precedingOp);
    auto loc = op.getLoc();
    auto resultTypes = llvm::to_vector<16>(llvm::map_range(
        newUnmappedResults, [](Value v) { return v.getType(); }));
    auto newOp = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, resultTypes, newOperands);
    auto operandTypes = llvm::to_vector<16>(
        llvm::map_range(newOperands, [](Value v) { return v.getType(); }));
    Block *newBody =
        rewriter.createBlock(&newOp.body(), {}, operandTypes,
                             SmallVector<Location>(operandTypes.size(), loc));
    rewriter.setInsertionPointToStart(newBody);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (const auto &it : llvm::enumerate(precedingBody->getArguments()))
      bvm.map(it.value(), newBody->getArgument(it.index()));
    for (Operation &nestedOp : precedingBody->without_terminator())
      rewriter.clone(nestedOp, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value blockArg, operand;
      std::tie(blockArg, operand) = it;
      if (operand.getDefiningOp() == precedingOp) {
        auto where = llvm::find(precedingOp.results(), operand);
        assert(where.getBase() != nullptr &&
               "expected to find operand among the preceding cluster results");
        bvm.map(blockArg,
                bvm.lookup(precedingYieldOp.getOperand(where.getIndex())));
      } else {
        auto where = llvm::find(newOp.operands(), operand);
        bvm.map(blockArg, newBody->getArgument(where.getIndex()));
      }
    }
    for (Operation &nestedOp : body->without_terminator()) {
      rewriter.clone(nestedOp, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(newUnmappedResults, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> precedingOpReplacements;
    int64_t i = 0;
    for (Value result : precedingOp.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = newOp->getResult(i++);
      }
      precedingOpReplacements.push_back(replacement);
    }
    ValueRange opReplacements = newOp.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, opReplacements);
    rewriter.replaceOp(precedingOp, precedingOpReplacements);

    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::ChloDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    mhlo::populateRankSpecializationClusterPatterns(ctx, &patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

/// Lower rank specialization cluster to SCF.

bool isScalarTensorType(Type ty) {
  auto rankedTy = ty.dyn_cast<RankedTensorType>();
  return rankedTy && rankedTy.getRank() == 0;
}

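// Shape values in this pass are extent tensors; a scalar operand's shape has
// static size 0 in its single dimension, i.e. `tensor<0xindex>`.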
bool isScalarShapeType(Type ty) {
  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
}

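// Derive a ranked tensor type of the given rank with fully dynamic dimensions
// and the same element type, e.g. rank 2 turns `tensor<*xf32>` into
// `tensor<?x?xf32>`. Non-tensor types are returned unchanged.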
Type deriveRankedTensorTypes(Type ty, int64_t rank) {
  auto tensorTy = ty.dyn_cast<TensorType>();
  if (!tensorTy) return ty;
  SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
  return RankedTensorType::get(shape, tensorTy.getElementType());
}

Type deriveUnrankedTensorTypes(Type ty) {
  if (auto rankedTy = ty.dyn_cast<RankedTensorType>())
    return UnrankedTensorType::get(rankedTy.getElementType());
  return ty;
}

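// Clone the cluster's body operations, taking the (already ranked) operands
// from `bvm` and giving every result a ranked type whose rank is the maximum
// rank among the mapped operands. Returns the ranked values corresponding to
// the cluster's yielded results.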
SmallVector<Value, 8> materializeRankedOperations(
    OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
    chlo::RankSpecializationClusterOp op) {
  // Create ranked operations.
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    auto mappedOperands = llvm::to_vector<4>(llvm::map_range(
        nestedOp.getOperands(), [&](Value v) { return bvm.lookup(v); }));
    int64_t targetRank = 0;
    for (Value v : mappedOperands) {
      targetRank =
          std::max(targetRank, v.getType().cast<RankedTensorType>().getRank());
    }
    auto rankedResultTypes = llvm::to_vector<2>(
        llvm::map_range(nestedOp.getResultTypes(), [targetRank](Type ty) {
          return deriveRankedTensorTypes(ty, targetRank);
        }));
    OperationState rankedOpState(loc, nestedOp.getName().getStringRef(),
                                 mappedOperands, rankedResultTypes,
                                 nestedOp.getAttrs());
    Operation *rankedOp = b.create(rankedOpState);
    for (auto it : llvm::zip(nestedOp.getResults(), rankedOp->getResults()))
      bvm.map(std::get<0>(it), std::get<1>(it));
  }

  // Collect ranked results.
  auto yieldOp = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  return llvm::to_vector<8>(llvm::map_range(
      yieldOp.results(), [&](Value v) { return bvm.lookup(v); }));
}

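// Reify the cluster's result shape and reshape the single unranked result back
// to it. Also makes the reified shape computation independent of operations
// inside the original cluster so that the cluster can be erased afterwards.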
SmallVector<Value, 8> materializeFinalReshape(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange unshapedResults) {
  auto yieldOp = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  assert(unshapedResults.size() == 1 && yieldOp.results().size() == 1 &&
         "Currently, rank specialization supports only one result.");

  // Reify result shape.
  Operation *lastOpBeforeShapeReification = op->getPrevNode();
  SmallVector<Value, 1> resultShape;
  Value originalResult = yieldOp.results().front();
  auto originalResultIface =
      llvm::cast<InferShapedTypeOpInterface>(originalResult.getDefiningOp());
  if (failed(originalResultIface.reifyReturnTypeShapes(
          rewriter, originalResultIface->getOperands(), resultShape))) {
    return {};
  }

  // Materialize final reshape.
  Value unshapedResult = unshapedResults.front();
  Value result = rewriter.create<mhlo::DynamicReshapeOp>(
      loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult,
      resultShape.front());

  // Reify shapes until they are independent of operations in the original
  // cluster.
  {
    Operation *it = resultShape.front().getDefiningOp();
    while (it != nullptr && it != lastOpBeforeShapeReification) {
      bool advanced = false;
      if (auto shapeOfOp = llvm::dyn_cast<shape::ShapeOfOp>(it)) {
        Operation *def = shapeOfOp.getArg().getDefiningOp();
        if (def && def->getBlock() == op.getBody()) {
          // Resolve the `shape_of` op because it still depends on an operation
          // in the original cluster.
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(shapeOfOp);
          SmallVector<Value, 1> tmpShape;
          auto iface = llvm::cast<InferShapedTypeOpInterface>(def);
          if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(),
                                                 tmpShape)))
            return {};
          rewriter.replaceOp(shapeOfOp, tmpShape.front());

          // Continue, including the newly created operations.
          it = tmpShape.front().getDefiningOp();
          advanced = true;
        }
      }

      // Otherwise, skip the op.
      if (!advanced) it = it->getPrevNode();
    }
  }

  // Replace all remaining uses of the original cluster's block args.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments())) {
    Value operand, barg;
    std::tie(operand, barg) = it;
    barg.replaceUsesWithIf(operand, [&](OpOperand &operand) {
      return operand.getOwner()->getBlock() != op.getBody();
    });
  }

  return {result};
}

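// Compute the flattened (rank-1) shape shared by `sameShapes`: a 1-element
// extent tensor holding the total number of elements.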
Value materializeFlatShape(OpBuilder &b, Location loc, ValueRange sameShapes) {
  assert(!sameShapes.empty() && "Expected at least one shape.");
  Value shape = sameShapes.size() == 1
                    ? sameShapes.front()
                    : b.create<shape::AnyOp>(loc, sameShapes.front().getType(),
                                             sameShapes);
  return b.create<tensor::FromElementsOp>(
      loc,
      b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).getResult());
}

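// Materialize the case in which all operands, except those expected to share
// the `nonScalarsOfSameShape` shape, are scalars at runtime. In that case the
// non-scalars can be flattened to rank 1 and all other operands reshaped to
// rank 0 before applying the cluster's operations.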
Value materializeScalarRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, ValueRange nonScalarsOfSameShape,
    function_ref<void(OpBuilder &, Location)> elseBuilderFn) {
  // Materialize predicate: All operands are scalars, except the expected
  // non-scalars.
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value allOthersAreScalar;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (llvm::is_contained(nonScalarsOfSameShape, operand) ||
        isScalarTensorType(operand.getType())) {
      continue;
    }
    auto literal = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq,
        b.create<shape::NumElementsOp>(loc, shape), one);
    allOthersAreScalar =
        allOthersAreScalar
            ? b.create<mlir::arith::AndIOp>(loc, allOthersAreScalar, literal)
                  .getResult()
            : literal.getResult();
  }

  auto ifOp = b.create<scf::IfOp>(
      loc, op->getResultTypes(), allOthersAreScalar,
      [&](OpBuilder &b, Location loc) {
        // Compute flat non-scalar shape.
        SmallVector<Value, 4> nonScalarShapes;
        for (auto it : llvm::zip(op.operands(), shapes)) {
          Value operand, shape;
          std::tie(operand, shape) = it;
          if (llvm::is_contained(nonScalarsOfSameShape, operand))
            nonScalarShapes.push_back(shape);
        }
        Value flatShape = materializeFlatShape(b, loc, nonScalarShapes);

        // Derive ranked operands.
        auto rankedOperands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (isScalarTensorType(v.getType())) return v;
              if (!llvm::is_contained(nonScalarsOfSameShape, v)) {
                return b
                    .create<mhlo::ReshapeOp>(
                        loc, deriveRankedTensorTypes(v.getType(), /*rank=*/0),
                        v)
                    .getResult();
              }
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flatShape)
                  .getResult();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), rankedOperands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshapedResult =
            materializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                      loc, deriveUnrankedTensorTypes(unshapedResult.getType()),
                      unshapedResult)
                     .getDest());
      },
      elseBuilderFn);

  return ifOp.getResults().front();
}

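// Materialize the case in which all non-scalar operands have equal shapes at
// runtime. All operands can then be flattened to rank 1 (scalars are passed
// through unchanged) before applying the cluster's operations.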
Value materializeEqualShapesRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes,
    function_ref<void(OpBuilder &, Location)> elseBuilderFn) {
  // Materialize all shapes equal predicate.
  Value allShapesEqOrScalar;
  auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !isScalarShapeType(v.getType()); }));
  assert(
      nonScalarShapes.size() >= 2 &&
      "Equal shapes strategy requires at least two non-scalar operand shapes.");
  for (Value s : llvm::drop_begin(nonScalarShapes)) {
    auto literal = b.create<shape::ShapeEqOp>(loc, nonScalarShapes.front(), s);
    allShapesEqOrScalar =
        allShapesEqOrScalar
            ? b.create<mlir::arith::AndIOp>(loc, allShapesEqOrScalar, literal)
                  .getResult()
            : literal;
  }

  auto ifOp = b.create<scf::IfOp>(
      loc, op->getResultTypes(), allShapesEqOrScalar,
      [&](OpBuilder &b, Location loc) {
        // Flatten non-scalar operands.
        Value flatShape = materializeFlatShape(b, loc, nonScalarShapes);
        auto flatOperands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (isScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flatShape)
                  .result();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), flatOperands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshapedResult =
            materializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                      loc, deriveUnrankedTensorTypes(unshapedResult.getType()),
                      unshapedResult)
                     .getDest());
      },
      elseBuilderFn);

  return ifOp.getResults().front();
}

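// Materialize the cluster for a fixed target rank: broadcast every unranked
// operand's shape with an all-ones shape of the target rank, reshape the
// operand accordingly, and apply the ranked operations.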
Value materializeTargetRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t targetRank) {
  // Reshape unranked operands to match the target rank.
  RankedTensorType extentTensorTy =
      shape::getExtentTensorType(b.getContext(), targetRank);
  Value allOnesShape = b.create<shape::ConstShapeOp>(
      loc, extentTensorTy,
      mlir::DenseIntElementsAttr::get(extentTensorTy,
                                      SmallVector<int64_t, 6>(targetRank, 1)));
  SmallVector<Value, 8> rankedOperands;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (operand.getType().isa<RankedTensorType>()) {
      rankedOperands.push_back(operand);
      continue;
    }
    Value rankedShape = b.create<tensor::CastOp>(
        loc, extentTensorTy,
        b.create<shape::BroadcastOp>(loc,
                                     shape::getExtentTensorType(b.getContext()),
                                     shape, allOnesShape,
                                     /*error=*/nullptr));
    rankedOperands.push_back(b.create<mhlo::DynamicReshapeOp>(
        loc, deriveRankedTensorTypes(operand.getType(), targetRank), operand,
        rankedShape));
  }

  // Materialize ranked versions of the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.body().front().getArguments(), rankedOperands))
    bvm.map(std::get<0>(it), std::get<1>(it));

  // Return as unranked for compatibility with other target ranks.
  auto unshapedResult = materializeRankedOperations(b, loc, bvm, op).front();
  return b.create<tensor::CastOp>(
      loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult);
}

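// Recursively materialize a cascade of `scf.if` cases for the target ranks
// `minTargetRank` .. `maxTargetRank`, dispatching on the maximum reduced
// operand rank. The innermost case asserts that the rank bound was sufficient.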
Value recusivelyMaterializeTargetRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, Value maxRank, int64_t minTargetRank,
    int64_t maxTargetRank) {
  Value condition = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ule, maxRank,
      b.create<arith::ConstantIndexOp>(loc, minTargetRank));

  // If only a unique target rank is left, we can lower to an assert instead
  // of the usual if operation.
  if (minTargetRank == maxTargetRank) {
    b.create<cf::AssertOp>(
        loc, condition,
        "Input for dynamic binary or n-ary op lowering was of "
        "a rank greater than " +
            std::to_string(maxTargetRank));
    return materializeTargetRankSpecializationCase(b, loc, op, shapes,
                                                   minTargetRank);
  }

  // Materialize IR for the smallest considered target rank.
  auto ifOp = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
                                  /*withElseRegion=*/true);
  auto thenBuilder = ifOp.getThenBodyBuilder();
  thenBuilder.create<scf::YieldOp>(
      loc, materializeTargetRankSpecializationCase(thenBuilder, loc, op, shapes,
                                                   minTargetRank));

  // Recurse for all remaining target ranks.
  auto elseBuilder = ifOp.getElseBodyBuilder();
  elseBuilder.create<scf::YieldOp>(
      loc, recusivelyMaterializeTargetRankSpecializationCases(
               elseBuilder, loc, op, shapes, maxRank, minTargetRank + 1,
               maxTargetRank));

  return ifOp.getResults().front();
}

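// Materialize the generic case: reduce the operand shapes with
// `chlo::MinimumBroadcastShapesOp`, find the maximum rank among the reduced
// shapes, and dispatch to the per-rank cases via the recursive cascade above.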
Value materializeGenericRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t maxTargetRank) {
  // Get the minimum broadcast shapes of the operands.
  auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !isScalarShapeType(v.getType()); }));
  auto minBcastShapesOp = b.create<chlo::MinimumBroadcastShapesOp>(
      loc,
      SmallVector<Type, 8>(nonScalarShapes.size(),
                           shape::getExtentTensorType(b.getContext())),
      nonScalarShapes);

  // Find the maximum rank among the reduced operand shapes.
  Value maxRank;
  for (Value shape : minBcastShapesOp.results()) {
    Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
    if (!maxRank) {
      maxRank = rank;
    } else {
      maxRank = b.create<mlir::arith::SelectOp>(
          loc,
          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, maxRank,
                                  rank),
          maxRank, rank);
    }
  }

  // Collect reduced shapes.
  SmallVector<Value, 8> reducedShapes;
  auto it = minBcastShapesOp.result_begin();
  for (Value s : shapes) {
    if (isScalarShapeType(s.getType())) {
      reducedShapes.push_back(s);
    } else {
      reducedShapes.push_back(*it++);
    }
  }

  // Materialize rank specialization for ranks 1, ...
  return recusivelyMaterializeTargetRankSpecializationCases(
      b, loc, op, reducedShapes, maxRank, /*minTargetRank=*/1, maxTargetRank);
}

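// Materialize the default case structure: try the equal-shapes fast path first
// and fall back to the generic per-rank cases otherwise.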
Value materializeDefaultRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t maxTargetRank) {
  return materializeEqualShapesRankSpecializationCase(
      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, materializeGenericRankSpecializationCases(
                                        b, loc, op, shapes, maxTargetRank));
      });
}

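// If all non-scalar operands are known to share one shape equivalence class,
// no runtime case distinction is needed: flatten the non-scalars to rank 1,
// apply the cluster's operations, and restore the original shape.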
SmallVector<Value, 8>
materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange nonScalarsOfSameShape) {
  // Compute flat operand shape.
  auto nonScalarShapes =
      llvm::to_vector<4>(llvm::map_range(nonScalarsOfSameShape, [&](Value v) {
        return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
      }));
  Value flatShape = materializeFlatShape(rewriter, loc, nonScalarShapes);

  // Materialize ranked variants for the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
    Value operand;
    Value bbArg;
    std::tie(bbArg, operand) = it;
    if (!isScalarTensorType(operand.getType())) {
      assert(llvm::is_contained(nonScalarsOfSameShape, operand) &&
             "Expected all non-scalars in the same shape equivalence class.");
      operand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, deriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
          flatShape);
    }
    bvm.map(bbArg, operand);
  }
  SmallVector<Value, 8> unshapedResults =
      materializeRankedOperations(rewriter, loc, bvm, op);

  // Restore the results' expected shape.
  Value shape = nonScalarShapes.front();
  return llvm::to_vector<8>(llvm::map_range(unshapedResults, [&](Value v) {
    return rewriter
        .create<mhlo::DynamicReshapeOp>(
            loc, deriveUnrankedTensorTypes(v.getType()), v, shape)
        .result();
  }));
}

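// With exactly two non-scalar shape equivalence classes, handle two extra
// runtime cases first: if either class turns out to be all scalars, the
// remaining non-scalars can be flattened. Otherwise, fall back to the default
// rank specialization cases.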
Value materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs, int64_t maxTargetRank) {
  assert(nonScalarEqs.size() == 2 &&
         "Expect two non-scalar equivalence classes.");
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
  }));
  ValueRange lhsNonScalarEqs = nonScalarEqs[0];
  ValueRange rhsNonScalarEqs = nonScalarEqs[1];

  // Materialize all the different cases.
  Value unshapedResult = materializeScalarRankSpecializationCase(
      rewriter, loc, op, shapes, rhsNonScalarEqs,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(
            loc, materializeScalarRankSpecializationCase(
                     b, loc, op, shapes, lhsNonScalarEqs,
                     [&](OpBuilder &b, Location loc) {
                       b.create<scf::YieldOp>(
                           loc, materializeDefaultRankSpecializationCases(
                                    b, loc, op, shapes, maxTargetRank));
                     }));
      });

  // Materialize final reshape once and for all rank specialization cases.
  return materializeFinalReshape(rewriter, loc, op, unshapedResult).front();
}

// Materialize the generic rank specialization.
Value materializeDefaultRankSpecialization(PatternRewriter &rewriter,
                                           Location loc,
                                           chlo::RankSpecializationClusterOp op,
                                           int64_t maxTargetRank) {
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
  }));

  // Materialize all the different cases.
  Value unshapedResult = materializeDefaultRankSpecializationCases(
      rewriter, loc, op, shapes, maxTargetRank);

  // Materialize final reshape once and for all rank specialization cases.
  return materializeFinalReshape(rewriter, loc, op, unshapedResult).front();
}

// This is a very limited form of shape inference. It is correct but
// incomplete.
SmallVector<SmallVector<Value, 4>, 4> findNonScalarShapeEquivalences(
    chlo::RankSpecializationClusterOp op) {
  llvm::EquivalenceClasses<Value> eqs;

  // Bridge the equivalences between operands and block arguments.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
    eqs.unionSets(std::get<0>(it), std::get<1>(it));

  // Find equalities through `SameOperandsAndResultShape` trait.
  auto unionSets = [&](ValueRange vs) {
    if (vs.empty()) return;
    Value repr = vs.front();
    for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
  };
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    if (nestedOp.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      unionSets(nestedOp.getOperands());
      unionSets(nestedOp.getResults());
      if (!nestedOp.getOperands().empty() && !nestedOp.getResults().empty())
        eqs.unionSets(nestedOp.getResult(0), nestedOp.getOperand(0));
    }
  }

  // Find shape equalities through surrounding constraints.
  if (auto assumingOp = op->getParentOfType<shape::AssumingOp>()) {
    SmallVector<Operation *, 8> queue;
    auto appendIfNotNull = [&](Operation *op) {
      if (op != nullptr) queue.push_back(op);
    };
    appendIfNotNull(assumingOp.getWitness().getDefiningOp());
    while (!queue.empty()) {
      Operation *it = queue.pop_back_val();
      if (auto assumingAllOp = llvm::dyn_cast<shape::AssumingAllOp>(it)) {
        for (Value v : assumingAllOp.getInputs())
          appendIfNotNull(v.getDefiningOp());
      } else if (auto cstrEqOp = llvm::dyn_cast<shape::CstrEqOp>(it)) {
        Value refArg;
        for (Value v : cstrEqOp.getShapes()) {
          if (auto shapeOfOp =
                  dyn_cast_or_null<shape::ShapeOfOp>(v.getDefiningOp())) {
            if (!refArg) {
              refArg = shapeOfOp.getArg();
            } else {
              eqs.unionSets(refArg, shapeOfOp.getArg());
            }
          }
        }
      }
    }
  }

  // Find equalities through special knowledge of ops.
  // TODO(frgossen): Remove this when these shape equalities can be inferred
  // from surrounding shape constraints.
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    if (auto selectOp = llvm::dyn_cast<mhlo::SelectOp>(nestedOp)) {
      unionSets(
          {selectOp.on_true(), selectOp.on_false(), selectOp.getResult()});
    } else if (auto clampOp = llvm::dyn_cast<mhlo::ClampOp>(nestedOp)) {
      unionSets({clampOp.operand(), clampOp.getResult()});
    }
  }

  // Convert to a list-like equivalence class representation.
  SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs;
  for (Value v : op.operands()) {
    if (isScalarTensorType(v.getType())) continue;
    bool inserted = false;
    for (auto &eqClass : nonScalarEqs) {
      if (eqs.isEquivalent(eqClass.front(), v)) {
        eqClass.push_back(v);
        inserted = true;
        break;
      }
    }
    if (!inserted) nonScalarEqs.push_back(SmallVector<Value, 4>({v}));
  }

  return nonScalarEqs;
}

struct LowerRankSpecializationClusterPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  LowerRankSpecializationClusterPattern(MLIRContext *ctx, int64_t maxTargetRank)
      : OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
        maxTargetRank(maxTargetRank) {}

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    // Restoring the result shape currently relies on all operands being used
    // for a single result. The result shape is then the broadcasted shape of
    // all operands.
    if (op.getNumResults() != 1) return failure();

    // If there is only a single non-scalar shape equivalence class, we can
    // flatten the operands completely.
    SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs =
        findNonScalarShapeEquivalences(op);
    Location loc = op.getLoc();
    if (nonScalarEqs.size() == 1) {
      rewriter.replaceOp(
          op,
          materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
              rewriter, loc, op, nonScalarEqs.front()));
      return success();
    }

    // If there are exactly two non-scalar shape equivalence classes, we can
    // consider two extra cases: If either of the operand classes turns out to
    // be all-scalars at runtime, we can, again, flatten all operands.
    if (nonScalarEqs.size() == 2) {
      rewriter.replaceOp(
          op,
          materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
              rewriter, loc, op, nonScalarEqs, maxTargetRank));
      return success();
    }

    // For all other cases, reshape the operands to match in rank, apply the
    // operation, and restore the expected shape.
    rewriter.replaceOp(op, materializeDefaultRankSpecialization(
                               rewriter, loc, op, maxTargetRank));
    return success();
  }

 private:
  int64_t maxTargetRank;
};

struct RankSpecializationToSCFPass
    : public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
  explicit RankSpecializationToSCFPass(int64_t maxTargetRank)
      : RankSpecializationToSCFPassBase<
            RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
    this->max_target_rank_ = maxTargetRank;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::ChloDialect, func::FuncDialect,
                    shape::ShapeDialect, scf::SCFDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateRankSpecializationToSCFPatterns(ctx, &patterns,
                                            this->max_target_rank_);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void populateRankSpecializationClusterPatterns(MLIRContext *context,
                                               RewritePatternSet *patterns) {
  patterns->add<MergeRankSpecializationClusterOpsPattern,
                RankSpecializationClusterPattern>(context);
}

void populateRankSpecializationToSCFPatterns(MLIRContext *context,
                                             RewritePatternSet *patterns,
                                             int64_t maxTargetRank) {
  patterns->add<LowerRankSpecializationClusterPattern>(context, maxTargetRank);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context);
  shape::AnyOp::getCanonicalizationPatterns(*patterns, context);
}

std::unique_ptr<OperationPass<func::FuncOp>>
createRankSpecializationClusterPass() {
  return std::make_unique<RankSpecializationClusterPass>();
}

std::unique_ptr<OperationPass<func::FuncOp>> createRankSpecializationToSCFPass(
    int64_t maxTargetRank) {
  return std::make_unique<RankSpecializationToSCFPass>(maxTargetRank);
}

}  // namespace mhlo
}  // namespace mlir