/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <utility>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
/// `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}

namespace mhlo {
namespace {

/// Identify clusters of operations that can be rank-specialized together. The
/// required traits for clustered operations are:
///   - Element-wise: All operations in the group must be element-wise. This
///     allows reshaping operands before applying the operations as well as
///     reshaping the result to the desired shape afterwards. This way, we can,
///     e.g., apply unary ops to a completely flattened operand and restore the
///     original shape afterwards.
///   - Broadcasting semantics: All operations must implement broadcasting
///     semantics. Most importantly, this allows extending operand shapes such
///     that they match in rank. Operations that require all their operands to
///     be of the same shape also fulfill this requirement.
///   - Shape reification: All operations must implement
///     `InferShapedTypeOpInterface`. This is later needed to compute and to
///     restore the desired result shape.
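///
/// For example (illustrative IR only, not taken from a test), a chain of
/// CHLO/MHLO element-wise ops over unranked operands such as
///   %0 = chlo.broadcast_add %arg0, %arg1
///       : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
///   %1 = mhlo.tanh %0 : (tensor<*xf32>) -> tensor<*xf32>
/// satisfies all three requirements and can be clustered and rank-specialized
/// as a whole.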

bool isClusterable(Operation *op) {
  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
  if (op->getNumOperands() == 0) return false;
  return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
          op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
         op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
}

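// Wraps a maximal sequence of adjacent clusterable ops (at least one of which
// has an unranked operand) into a `chlo.rank_specialization_cluster` op whose
// body is later rank-specialized as a whole.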
struct RankSpecializationClusterPattern : public RewritePattern {
  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only apply to operations that have not been clustered yet.
    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
      return failure();
    }

    // Only cluster when rank specialization is needed.
    if (!isClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
          return ty.isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Collect all collectively rank specializable ops.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallSet<Value, 16> operandSet;
    llvm::SmallSet<Value, 16> resultSet;

    Operation *rootOp = op;
    while (rootOp->getNextNode() != nullptr &&
           isClusterable(rootOp->getNextNode()))
      rootOp = rootOp->getNextNode();

    Operation *it = rootOp;
    while (it != nullptr && isClusterable(it)) {
      // Find results that escape the cluster.
      for (OpOperand &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          resultSet.insert(use.get());
      }

      // Update cluster operands.
      for (OpResult v : it->getResults()) operandSet.erase(Value(v));
      for (OpOperand &v : it->getOpOperands()) operandSet.insert(v.get());

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Create `RankSpecializationClusterOp`.
    auto operands = llvm::to_vector<16>(operandSet);
    auto results = llvm::to_vector<16>(resultSet);
    auto resultTypes = llvm::to_vector<16>(
        llvm::map_range(resultSet, [](Value v) { return v.getType(); }));
    Location loc = op->getLoc();
    auto clusterOp = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, resultTypes, operands);

    // Create body block.
    auto operandTypes = llvm::to_vector<16>(
        llvm::map_range(operandSet, [](Value v) { return v.getType(); }));
    Block *block =
        rewriter.createBlock(&clusterOp.body(), {}, operandTypes,
                             SmallVector<Location>(operandTypes.size(), loc));

    // Copy operations into the body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::zip(operands, block->getArguments()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    rewriter.setInsertionPointToStart(block);
    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);

    // Create `RankSpecializationClusterYieldOp`.
    auto mappedResults = llvm::to_vector<16>(
        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc, mappedResults);

    // Replace original ops with the new results.
    for (auto it : llvm::zip(results, clusterOp.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};

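// Merges a `chlo.rank_specialization_cluster` op into an immediately preceding
// cluster op. Values that are only passed between the two clusters become
// internal values of the merged cluster.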
struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto precedingOp =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!precedingOp) return failure();
    Block *body = op.getBody();
    Block *precedingBody = precedingOp.getBody();
    auto yieldOp = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto precedingYieldOp =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            precedingOp.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> newOperands;
    for (Value v : precedingOp.operands()) newOperands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != precedingOp &&
          !llvm::is_contained(precedingOp.operands(), v)) {
        newOperands.push_back(v);
      }
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second cluster.
    SmallVector<Value, 8> newUnmappedResults;
    for (auto it :
         llvm::zip(precedingOp.results(), precedingYieldOp.results())) {
      Value result, innerResult;
      std::tie(result, innerResult) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        newUnmappedResults.push_back(innerResult);
      }
    }
    for (Value v : yieldOp.results()) newUnmappedResults.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(precedingOp);
    auto loc = op.getLoc();
    auto resultTypes = llvm::to_vector<16>(llvm::map_range(
        newUnmappedResults, [](Value v) { return v.getType(); }));
    auto newOp = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, resultTypes, newOperands);
    auto operandTypes = llvm::to_vector<16>(
        llvm::map_range(newOperands, [](Value v) { return v.getType(); }));
    Block *newBody =
        rewriter.createBlock(&newOp.body(), {}, operandTypes,
                             SmallVector<Location>(operandTypes.size(), loc));
    rewriter.setInsertionPointToStart(newBody);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (const auto &it : llvm::enumerate(precedingBody->getArguments()))
      bvm.map(it.value(), newBody->getArgument(it.index()));
    for (Operation &nestedOp : precedingBody->without_terminator())
      rewriter.clone(nestedOp, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value blockArg, operand;
      std::tie(blockArg, operand) = it;
      if (operand.getDefiningOp() == precedingOp) {
        auto where = llvm::find(precedingOp.results(), operand);
        assert(where.getBase() != nullptr &&
               "expected to find operand among the preceding cluster results");
        bvm.map(blockArg,
                bvm.lookup(precedingYieldOp.getOperand(where.getIndex())));
      } else {
        auto where = llvm::find(newOp.operands(), operand);
        bvm.map(blockArg, newBody->getArgument(where.getIndex()));
      }
    }
    for (Operation &nestedOp : body->without_terminator()) {
      rewriter.clone(nestedOp, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(newUnmappedResults, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> precedingOpReplacements;
    int64_t i = 0;
    for (Value result : precedingOp.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = newOp->getResult(i++);
      }
      precedingOpReplacements.push_back(replacement);
    }
    ValueRange opReplacements = newOp.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, opReplacements);
    rewriter.replaceOp(precedingOp, precedingOpReplacements);

    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::ChloDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    mhlo::populateRankSpecializationClusterPatterns(ctx, &patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

/// Lower rank specialization cluster to SCF.

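// Returns true if `ty` is a ranked tensor type of rank 0, i.e. a scalar
// tensor.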
bool isScalarTensorType(Type ty) {
  auto rankedTy = ty.dyn_cast<RankedTensorType>();
  return rankedTy && rankedTy.getRank() == 0;
}

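// Returns true if `ty` is the extent tensor type of a scalar shape, i.e. a
// shape with zero extents.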
bool isScalarShapeType(Type ty) {
  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
}

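// For tensor types, returns a ranked tensor type of the given rank with fully
// dynamic dimensions and the same element type. Non-tensor types are returned
// unchanged.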
Type deriveRankedTensorTypes(Type ty, int64_t rank) {
  auto tensorTy = ty.dyn_cast<TensorType>();
  if (!tensorTy) return ty;
  SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
  return RankedTensorType::get(shape, tensorTy.getElementType());
}

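// For ranked tensor types, returns the unranked tensor type with the same
// element type. Other types are returned unchanged.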
Type deriveUnrankedTensorTypes(Type ty) {
  if (auto rankedTy = ty.dyn_cast<RankedTensorType>())
    return UnrankedTensorType::get(rankedTy.getElementType());
  return ty;
}

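// Clones the cluster's body operations with ranked operand and result types.
// Operands are taken from `bvm`, which must already map the cluster's block
// arguments to ranked values. Each result type is derived with a rank equal to
// the maximum rank among the mapped operands. Returns the ranked values that
// correspond to the cluster's yielded results.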
SmallVector<Value, 8> materializeRankedOperations(
    OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
    chlo::RankSpecializationClusterOp op) {
  // Create ranked operations.
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    auto mappedOperands = llvm::to_vector<4>(llvm::map_range(
        nestedOp.getOperands(), [&](Value v) { return bvm.lookup(v); }));
    int64_t targetRank = 0;
    for (Value v : mappedOperands) {
      targetRank =
          std::max(targetRank, v.getType().cast<RankedTensorType>().getRank());
    }
    auto rankedResultTypes = llvm::to_vector<2>(
        llvm::map_range(nestedOp.getResultTypes(), [targetRank](Type ty) {
          return deriveRankedTensorTypes(ty, targetRank);
        }));
    OperationState rankedOpState(loc, nestedOp.getName().getStringRef(),
                                 mappedOperands, rankedResultTypes,
                                 nestedOp.getAttrs());
    Operation *rankedOp = b.create(rankedOpState);
    for (auto it : llvm::zip(nestedOp.getResults(), rankedOp->getResults()))
      bvm.map(std::get<0>(it), std::get<1>(it));
  }

  // Collect ranked results.
  auto yieldOp = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  return llvm::to_vector<8>(llvm::map_range(
      yieldOp.results(), [&](Value v) { return bvm.lookup(v); }));
}

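// Reifies the cluster's result shape and reshapes the unranked result value
// back to it. Shape computations that still depend on operations inside the
// original cluster are re-reified, and remaining uses of the cluster's block
// arguments are redirected to the corresponding operands. Currently limited to
// clusters with exactly one result.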
SmallVector<Value, 8> materializeFinalReshape(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange unshapedResults) {
  auto yieldOp = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  assert(unshapedResults.size() == 1 && yieldOp.results().size() == 1 &&
         "Currently, rank specialization supports only one result.");

  // Reify result shape.
  Operation *lastOpBeforeShapeReification = op->getPrevNode();
  SmallVector<Value, 1> resultShape;
  Value originalResult = yieldOp.results().front();
  auto originalResultIface =
      llvm::cast<InferShapedTypeOpInterface>(originalResult.getDefiningOp());
  if (failed(originalResultIface.reifyReturnTypeShapes(
          rewriter, originalResultIface->getOperands(), resultShape))) {
    return {};
  }

  // Materialize final reshape.
  Value unshapedResult = unshapedResults.front();
  Value result = rewriter.create<mhlo::DynamicReshapeOp>(
      loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult,
      resultShape.front());

  // Reify shapes until they are independent of operations in the original
  // cluster.
  {
    Operation *it = resultShape.front().getDefiningOp();
    while (it != nullptr && it != lastOpBeforeShapeReification) {
      bool advanced = false;
      if (auto shapeOfOp = llvm::dyn_cast<shape::ShapeOfOp>(it)) {
        Operation *def = shapeOfOp.getArg().getDefiningOp();
        if (def && def->getBlock() == op.getBody()) {
          // Resolve the `shape_of` op because it still depends on an operation
          // in the original cluster.
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(shapeOfOp);
          SmallVector<Value, 1> tmpShape;
          auto iface = llvm::cast<InferShapedTypeOpInterface>(def);
          if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(),
                                                 tmpShape)))
            return {};
          rewriter.replaceOp(shapeOfOp, tmpShape.front());

          // Continue, including the newly created operations.
          it = tmpShape.front().getDefiningOp();
          advanced = true;
        }
      }

      // Otherwise, skip the op.
      if (!advanced) it = it->getPrevNode();
    }
  }

  // Replace all remaining uses of the original cluster's block args.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments())) {
    Value operand, barg;
    std::tie(operand, barg) = it;
    barg.replaceUsesWithIf(operand, [&](OpOperand &operand) {
      return operand.getOwner()->getBlock() != op.getBody();
    });
  }

  return {result};
}

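// Materializes the 1-D shape that flattens operands of the given (equal)
// shapes: the shapes are combined with `shape.any` if needed, and the total
// number of elements becomes the single extent of the result.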
Value materializeFlatShape(OpBuilder &b, Location loc, ValueRange sameShapes) {
  assert(!sameShapes.empty() && "Expected at least one shape.");
  Value shape = sameShapes.size() == 1
                    ? sameShapes.front()
                    : b.create<shape::AnyOp>(loc, sameShapes.front().getType(),
                                             sameShapes);
  return b.create<tensor::FromElementsOp>(
      loc,
      b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).getResult());
}

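// Materializes the case in which all operands, except for the designated
// non-scalars of the same shape, turn out to be scalars at runtime. In that
// case, scalar operands are reshaped to rank 0, the non-scalars are flattened
// to rank 1, and the cluster body is applied to the ranked values. The `else`
// branch of the generated `scf.if` is filled by `elseBuilderFn`.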
Value materializeScalarRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, ValueRange nonScalarsOfSameShape,
    function_ref<void(OpBuilder &, Location)> elseBuilderFn) {
  // Materialize predicate: All operands are scalars, except the expected
  // non-scalars.
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value allOthersAreScalar;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (llvm::is_contained(nonScalarsOfSameShape, operand) ||
        isScalarTensorType(operand.getType())) {
      continue;
    }
    auto literal = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq,
        b.create<shape::NumElementsOp>(loc, shape), one);
    allOthersAreScalar =
        allOthersAreScalar
            ? b.create<mlir::arith::AndIOp>(loc, allOthersAreScalar, literal)
                  .getResult()
            : literal.getResult();
  }

  auto ifOp = b.create<scf::IfOp>(
      loc, op->getResultTypes(), allOthersAreScalar,
      [&](OpBuilder &b, Location loc) {
        // Compute flat non-scalar shape.
        SmallVector<Value, 4> nonScalarShapes;
        for (auto it : llvm::zip(op.operands(), shapes)) {
          Value operand, shape;
          std::tie(operand, shape) = it;
          if (llvm::is_contained(nonScalarsOfSameShape, operand))
            nonScalarShapes.push_back(shape);
        }
        Value flatShape = materializeFlatShape(b, loc, nonScalarShapes);

        // Derive ranked operands.
        auto rankedOperands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (isScalarTensorType(v.getType())) return v;
              if (!llvm::is_contained(nonScalarsOfSameShape, v)) {
                return b
                    .create<mhlo::ReshapeOp>(
                        loc, deriveRankedTensorTypes(v.getType(), /*rank=*/0),
                        v)
                    .getResult();
              }
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flatShape)
                  .getResult();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), rankedOperands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshapedResult =
            materializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                     loc, deriveUnrankedTensorTypes(unshapedResult.getType()),
                     unshapedResult)
                     .getDest());
      },
      elseBuilderFn);

  return ifOp.getResults().front();
}

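// Materializes the case in which all non-scalar operands have the same shape
// at runtime. In that case, the non-scalar operands are flattened to rank 1
// and the cluster body is applied to the flat values. The `else` branch of the
// generated `scf.if` is filled by `elseBuilderFn`.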
Value materializeEqualShapesRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes,
    function_ref<void(OpBuilder &, Location)> elseBuilderFn) {
  // Materialize all shapes equal predicate.
  Value allShapesEqOrScalar;
  auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !isScalarShapeType(v.getType()); }));
  assert(
      nonScalarShapes.size() >= 2 &&
      "Equal shapes strategy requires at least two non-scalar operand shapes.");
  for (Value s : llvm::drop_begin(nonScalarShapes)) {
    auto literal = b.create<shape::ShapeEqOp>(loc, nonScalarShapes.front(), s);
    allShapesEqOrScalar =
        allShapesEqOrScalar
            ? b.create<mlir::arith::AndIOp>(loc, allShapesEqOrScalar, literal)
                  .getResult()
            : literal;
  }

  auto ifOp = b.create<scf::IfOp>(
      loc, op->getResultTypes(), allShapesEqOrScalar,
      [&](OpBuilder &b, Location loc) {
        // Flatten non-scalar operands.
        Value flatShape = materializeFlatShape(b, loc, nonScalarShapes);
        auto flatOperands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (isScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flatShape)
                  .result();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), flatOperands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshapedResult =
            materializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                     loc, deriveUnrankedTensorTypes(unshapedResult.getType()),
                     unshapedResult)
                     .getDest());
      },
      elseBuilderFn);

  return ifOp.getResults().front();
}

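// Materializes the cluster for a fixed target rank: unranked operands are
// reshaped to `targetRank` (padding their shapes with unit dimensions via a
// broadcast with an all-ones shape), the ranked body is materialized, and the
// result is cast back to an unranked tensor.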
Value materializeTargetRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t targetRank) {
  // Reshape unranked operands to match the target rank.
  RankedTensorType extentTensorTy =
      shape::getExtentTensorType(b.getContext(), targetRank);
  Value allOnesShape = b.create<shape::ConstShapeOp>(
      loc, extentTensorTy,
      mlir::DenseIntElementsAttr::get(extentTensorTy,
                                      SmallVector<int64_t, 6>(targetRank, 1)));
  SmallVector<Value, 8> rankedOperands;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (operand.getType().isa<RankedTensorType>()) {
      rankedOperands.push_back(operand);
      continue;
    }
    Value rankedShape = b.create<tensor::CastOp>(
        loc, extentTensorTy,
        b.create<shape::BroadcastOp>(loc,
                                     shape::getExtentTensorType(b.getContext()),
                                     shape, allOnesShape,
                                     /*error=*/nullptr));
    rankedOperands.push_back(b.create<mhlo::DynamicReshapeOp>(
        loc, deriveRankedTensorTypes(operand.getType(), targetRank), operand,
        rankedShape));
  }

  // Materialize ranked versions of the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.body().front().getArguments(), rankedOperands))
    bvm.map(std::get<0>(it), std::get<1>(it));

  // Return as unranked for compatibility with other target ranks.
  auto unshapedResult = materializeRankedOperations(b, loc, bvm, op).front();
  return b.create<tensor::CastOp>(
      loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult);
}

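// Emits a cascade of `scf.if` ops that dispatch on `maxRank`, materializing
// one specialization per target rank from `minTargetRank` up to
// `maxTargetRank`. The last remaining rank is guarded by a `cf.assert` instead
// of a further `scf.if`.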
Value recursivelyMaterializeTargetRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, Value maxRank, int64_t minTargetRank,
    int64_t maxTargetRank) {
  Value condition = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ule, maxRank,
      b.create<arith::ConstantIndexOp>(loc, minTargetRank));

  // If only a single target rank is left, we can lower to an assert instead
  // of the usual if operation.
  if (minTargetRank == maxTargetRank) {
    b.create<cf::AssertOp>(
        loc, condition,
        "Input for dynamic binary or n-ary op lowering was of "
        "a rank greater than " +
            std::to_string(maxTargetRank));
    return materializeTargetRankSpecializationCase(b, loc, op, shapes,
                                                   minTargetRank);
  }

  // Materialize IR for the smallest considered target rank.
  auto ifOp = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
                                  /*withElseRegion=*/true);
  auto thenBuilder = ifOp.getThenBodyBuilder();
  thenBuilder.create<scf::YieldOp>(
      loc, materializeTargetRankSpecializationCase(thenBuilder, loc, op, shapes,
                                                   minTargetRank));

  // Recurse for all remaining target ranks.
  auto elseBuilder = ifOp.getElseBodyBuilder();
  elseBuilder.create<scf::YieldOp>(
      loc, recursivelyMaterializeTargetRankSpecializationCases(
               elseBuilder, loc, op, shapes, maxRank, minTargetRank + 1,
               maxTargetRank));

  return ifOp.getResults().front();
}

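// Materializes the generic fallback: computes the minimum broadcast shapes of
// all non-scalar operands, determines the maximum rank among the reduced
// shapes, and dispatches to per-rank specializations for ranks 1 through
// `maxTargetRank`.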
Value materializeGenericRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t maxTargetRank) {
  // Get the minimum broadcast shapes of the operands.
  auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !isScalarShapeType(v.getType()); }));
  auto minBcastShapesOp = b.create<chlo::MinimumBroadcastShapesOp>(
      loc,
      SmallVector<Type, 8>(nonScalarShapes.size(),
                           shape::getExtentTensorType(b.getContext())),
      nonScalarShapes);

  // Find the maximum rank among the reduced operand shapes.
  Value maxRank;
  for (Value shape : minBcastShapesOp.results()) {
    Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
    if (!maxRank) {
      maxRank = rank;
    } else {
      maxRank = b.create<mlir::arith::SelectOp>(
          loc,
          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, maxRank,
                                  rank),
          maxRank, rank);
    }
  }

  // Collect reduced shapes.
  SmallVector<Value, 8> reducedShapes;
  auto it = minBcastShapesOp.result_begin();
  for (Value s : shapes) {
    if (isScalarShapeType(s.getType())) {
      reducedShapes.push_back(s);
    } else {
      reducedShapes.push_back(*it++);
    }
  }

  // Materialize rank specialization for ranks 1 through `maxTargetRank`.
  return recursivelyMaterializeTargetRankSpecializationCases(
      b, loc, op, reducedShapes, maxRank, /*minTargetRank=*/1, maxTargetRank);
}

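// Materializes the default strategy: try the equal-shapes case first and fall
// back to the generic per-rank cases otherwise.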
Value materializeDefaultRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t maxTargetRank) {
  return materializeEqualShapesRankSpecializationCase(
      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, materializeGenericRankSpecializationCases(
                                        b, loc, op, shapes, maxTargetRank));
      });
}

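// Specialization for clusters in which all non-scalar operands belong to a
// single shape equivalence class: all of them are flattened to rank 1, the
// body is applied, and the results are reshaped back to the common shape.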
SmallVector<Value, 8>
materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange nonScalarsOfSameShape) {
  // Compute flat operand shape.
  auto nonScalarShapes =
      llvm::to_vector<4>(llvm::map_range(nonScalarsOfSameShape, [&](Value v) {
        return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
      }));
  Value flatShape = materializeFlatShape(rewriter, loc, nonScalarShapes);

  // Materialize ranked variants for the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
    Value operand;
    Value bbArg;
    std::tie(bbArg, operand) = it;
    if (!isScalarTensorType(operand.getType())) {
      assert(llvm::is_contained(nonScalarsOfSameShape, operand) &&
             "Expected all non-scalars in the same shape equivalence class.");
      operand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, deriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
          flatShape);
    }
    bvm.map(bbArg, operand);
  }
  SmallVector<Value, 8> unshapedResults =
      materializeRankedOperations(rewriter, loc, bvm, op);

  // Restore the results' expected shape.
  Value shape = nonScalarShapes.front();
  return llvm::to_vector<8>(llvm::map_range(unshapedResults, [&](Value v) {
    return rewriter
        .create<mhlo::DynamicReshapeOp>(
            loc, deriveUnrankedTensorTypes(v.getType()), v, shape)
        .result();
  }));
}

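// Specialization for clusters with exactly two non-scalar shape equivalence
// classes: first check whether either class is all-scalars at runtime (which
// again allows complete flattening) and otherwise fall back to the default
// rank specialization cases. The final reshape is materialized once for all
// cases.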
Value materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs, int64_t maxTargetRank) {
  assert(nonScalarEqs.size() == 2 &&
         "Expect two non-scalar equivalence classes.");
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
  }));
  ValueRange lhsNonScalarEqs = nonScalarEqs[0];
  ValueRange rhsNonScalarEqs = nonScalarEqs[1];

  // Materialize all the different cases.
  Value unshapedResult = materializeScalarRankSpecializationCase(
      rewriter, loc, op, shapes, rhsNonScalarEqs,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(
            loc, materializeScalarRankSpecializationCase(
                     b, loc, op, shapes, lhsNonScalarEqs,
                     [&](OpBuilder &b, Location loc) {
                       b.create<scf::YieldOp>(
                           loc, materializeDefaultRankSpecializationCases(
                                    b, loc, op, shapes, maxTargetRank));
                     }));
      });

  // Materialize final reshape once and for all rank specialization cases.
  return materializeFinalReshape(rewriter, loc, op, unshapedResult).front();
}

// Materialize generic rank specialization.
Value materializeDefaultRankSpecialization(PatternRewriter &rewriter,
                                           Location loc,
                                           chlo::RankSpecializationClusterOp op,
                                           int64_t maxTargetRank) {
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).getResult();
  }));

  // Materialize all the different cases.
  Value unshapedResult = materializeDefaultRankSpecializationCases(
      rewriter, loc, op, shapes, maxTargetRank);

  // Materialize final reshape once and for all rank specialization cases.
  return materializeFinalReshape(rewriter, loc, op, unshapedResult).front();
}

// This is a very limited form of shape inference. It is correct but incomplete.
SmallVector<SmallVector<Value, 4>, 4> findNonScalarShapeEquivalences(
    chlo::RankSpecializationClusterOp op) {
  llvm::EquivalenceClasses<Value> eqs;

  // Bridge the equivalences between operands and block arguments.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
    eqs.unionSets(std::get<0>(it), std::get<1>(it));

  // Find equalities through `SameOperandsAndResultShape` trait.
  auto unionSets = [&](ValueRange vs) {
    if (vs.empty()) return;
    Value repr = vs.front();
    for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
  };
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    if (nestedOp.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      unionSets(nestedOp.getOperands());
      unionSets(nestedOp.getResults());
      if (!nestedOp.getOperands().empty() && !nestedOp.getResults().empty())
        eqs.unionSets(nestedOp.getResult(0), nestedOp.getOperand(0));
    }
  }

  // Find shape equalities through surrounding constraints.
  if (auto assumingOp = op->getParentOfType<shape::AssumingOp>()) {
    SmallVector<Operation *, 8> queue;
    auto appendIfNotNull = [&](Operation *op) {
      if (op != nullptr) queue.push_back(op);
    };
    appendIfNotNull(assumingOp.getWitness().getDefiningOp());
    while (!queue.empty()) {
      Operation *it = queue.pop_back_val();
      if (auto assumingAllOp = llvm::dyn_cast<shape::AssumingAllOp>(it)) {
        for (Value v : assumingAllOp.getInputs())
          appendIfNotNull(v.getDefiningOp());
      } else if (auto cstrEqOp = llvm::dyn_cast<shape::CstrEqOp>(it)) {
        Value refArg;
        for (Value v : cstrEqOp.getShapes()) {
          if (auto shapeOfOp =
                  dyn_cast_or_null<shape::ShapeOfOp>(v.getDefiningOp())) {
            if (!refArg) {
              refArg = shapeOfOp.getArg();
            } else {
              eqs.unionSets(refArg, shapeOfOp.getArg());
            }
          }
        }
      }
    }
  }

  // Find equalities through special knowledge of ops.
  // TODO(frgossen): Remove this when these shape equalities can be inferred
  // from surrounding shape constraints.
  for (Operation &nestedOp : op.getBody()->without_terminator()) {
    if (auto selectOp = llvm::dyn_cast<mhlo::SelectOp>(nestedOp)) {
      unionSets(
          {selectOp.on_true(), selectOp.on_false(), selectOp.getResult()});
    } else if (auto clampOp = llvm::dyn_cast<mhlo::ClampOp>(nestedOp)) {
      unionSets({clampOp.operand(), clampOp.getResult()});
    }
  }

  // Convert to a list-like equivalence class representation.
  SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs;
  for (Value v : op.operands()) {
    if (isScalarTensorType(v.getType())) continue;
    bool inserted = false;
    for (auto &eqClass : nonScalarEqs) {
      if (eqs.isEquivalent(eqClass.front(), v)) {
        eqClass.push_back(v);
        inserted = true;
        break;
      }
    }
    if (!inserted) nonScalarEqs.push_back(SmallVector<Value, 4>({v}));
  }

  return nonScalarEqs;
}

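// Lowers a `chlo.rank_specialization_cluster` op to a cascade of SCF cases.
// The lowering strategy depends on the number of non-scalar shape equivalence
// classes among the operands.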
struct LowerRankSpecializationClusterPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  LowerRankSpecializationClusterPattern(MLIRContext *ctx, int64_t maxTargetRank)
      : OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
        maxTargetRank(maxTargetRank) {}

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    // Restoring the result shape currently relies on all operands being used
    // for a single result. The result shape is then the broadcasted shape of
    // all operands.
    if (op.getNumResults() != 1) return failure();

    // If there is only a single non-scalar shape equivalence class, we can
    // flatten the operands completely.
    SmallVector<SmallVector<Value, 4>, 4> nonScalarEqs =
        findNonScalarShapeEquivalences(op);
    Location loc = op.getLoc();
    if (nonScalarEqs.size() == 1) {
      rewriter.replaceOp(
          op,
          materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
              rewriter, loc, op, nonScalarEqs.front()));
      return success();
    }

    // If there are exactly two non-scalar shape equivalence classes, we can
    // consider two extra cases: If either of the operand classes turns out to
    // be all-scalars at runtime, we can, again, flatten all operands.
    if (nonScalarEqs.size() == 2) {
      rewriter.replaceOp(
          op,
          materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
              rewriter, loc, op, nonScalarEqs, maxTargetRank));
      return success();
    }

    // For all other cases, reshape the operands to match in rank, apply the
    // operation, and restore the expected shape.
    rewriter.replaceOp(op, materializeDefaultRankSpecialization(
                               rewriter, loc, op, maxTargetRank));
    return success();
  }

 private:
  int64_t maxTargetRank;
};


struct RankSpecializationToSCFPass
    : public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
  explicit RankSpecializationToSCFPass(int64_t maxTargetRank)
      : RankSpecializationToSCFPassBase<
            RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
    this->max_target_rank_ = maxTargetRank;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::ChloDialect, func::FuncDialect,
                    shape::ShapeDialect, scf::SCFDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateRankSpecializationToSCFPatterns(ctx, &patterns,
                                            this->max_target_rank_);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void populateRankSpecializationClusterPatterns(MLIRContext *context,
                                               RewritePatternSet *patterns) {
  patterns->add<MergeRankSpecializationClusterOpsPattern,
                RankSpecializationClusterPattern>(context);
}

void populateRankSpecializationToSCFPatterns(MLIRContext *context,
                                             RewritePatternSet *patterns,
                                             int64_t maxTargetRank) {
  patterns->add<LowerRankSpecializationClusterPattern>(context, maxTargetRank);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context);
  shape::AnyOp::getCanonicalizationPatterns(*patterns, context);
}

std::unique_ptr<OperationPass<func::FuncOp>>
createRankSpecializationClusterPass() {
  return std::make_unique<RankSpecializationClusterPass>();
}

std::unique_ptr<OperationPass<func::FuncOp>> createRankSpecializationToSCFPass(
    int64_t maxTargetRank) {
  return std::make_unique<RankSpecializationToSCFPass>(maxTargetRank);
}

}  // namespace mhlo
}  // namespace mlir