/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <algorithm>
#include <utility>

#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

// To avoid duplicate broadcasts, we collect all the intended broadcasts ahead
// of realizing any broadcasts in the IR. These are broadcasted versions of
// values that we are interested in, and they are uniquely characterized by a
// `BroadcastIntent` value.
struct BroadcastIntent {
  RankedTensorType resultType;
  Value targetValue;
  Value outputDimensions;
  Attribute broadcastDimensions;
  bool operator==(BroadcastIntent rhs) const {
    return resultType == rhs.resultType && targetValue == rhs.targetValue &&
           outputDimensions == rhs.outputDimensions &&
           broadcastDimensions == rhs.broadcastDimensions;
  }
  bool operator!=(BroadcastIntent rhs) const { return !(*this == rhs); }
};

}  // namespace
}  // namespace mhlo
}  // namespace mlir

namespace llvm {

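// Allow `BroadcastIntent` to be used as a key in DenseMap/DenseSet by
// delegating hashing and equality to the DenseMapInfo of its members.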
template <>
struct DenseMapInfo<mlir::mhlo::BroadcastIntent> {
  static mlir::mhlo::BroadcastIntent getEmptyKey() {
    return {DenseMapInfo<mlir::RankedTensorType>::getEmptyKey(),
            DenseMapInfo<mlir::Value>::getEmptyKey(),
            DenseMapInfo<mlir::Value>::getEmptyKey(),
            DenseMapInfo<mlir::Attribute>::getEmptyKey()};
  }
  static mlir::mhlo::BroadcastIntent getTombstoneKey() {
    return {DenseMapInfo<mlir::RankedTensorType>::getTombstoneKey(),
            DenseMapInfo<mlir::Value>::getTombstoneKey(),
            DenseMapInfo<mlir::Value>::getTombstoneKey(),
            DenseMapInfo<mlir::Attribute>::getTombstoneKey()};
  }
  static unsigned getHashValue(const mlir::mhlo::BroadcastIntent &intent) {
    return hash_combine(
        DenseMapInfo<mlir::RankedTensorType>::getHashValue(intent.resultType),
        DenseMapInfo<mlir::Value>::getHashValue(intent.targetValue),
        DenseMapInfo<mlir::Value>::getHashValue(intent.outputDimensions),
        DenseMapInfo<mlir::Attribute>::getHashValue(
            intent.broadcastDimensions));
  }
  static bool isEqual(const mlir::mhlo::BroadcastIntent &lhs,
                      const mlir::mhlo::BroadcastIntent &rhs) {
    return lhs == rhs;
  }
};

}  // namespace llvm

namespace mlir {
namespace mhlo {
namespace {

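// Returns true if `op` is a single-result element-wise operation over which a
// broadcast can be propagated by broadcasting its operands instead of its
// result. This covers ops with the Elementwise and SameOperandsAndResultShape
// traits as well as mhlo's broadcasting element-wise ops.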
bool allowsForElementwiseBroadcastPropagation(Operation *op) {
  if (op && op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() &&
      op->hasTrait<mlir::OpTrait::Elementwise>() && op->getNumResults() == 1) {
    return true;
  }
  if (op && op->hasTrait<mlir::mhlo::OpTrait::BroadcastingElementwise>() &&
      op->getNumResults() == 1) {
    return true;
  }
  return false;
}

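// Returns true if a broadcast can be propagated over `op`, i.e. if `op` is a
// dynamic_broadcast_in_dim op or an eligible element-wise op.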
bool allowsForBroadcastPropagation(Operation *op) {
  return llvm::isa_and_nonnull<DynamicBroadcastInDimOp>(op) ||
         allowsForElementwiseBroadcastPropagation(op);
}

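// Composes two broadcast_dimensions attributes `a` and `b` such that the
// result maps dimension i to b[a[i]], i.e. the broadcast dimensions of the
// combined broadcast obtained by applying `a` first and `b` second.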
DenseIntElementsAttr composeBroadcastDimensionsAttr(OpBuilder &builder,
                                                    DenseIntElementsAttr a,
                                                    DenseIntElementsAttr b) {
  SmallVector<int64_t> bVec =
      llvm::to_vector(llvm::map_range(b, [](const APInt &it) {
        return static_cast<int64_t>(it.getLimitedValue());
      }));
  SmallVector<int64_t> composedVec = llvm::to_vector(llvm::map_range(
      a, [bVec](const APInt &it) { return bVec[it.getLimitedValue()]; }));
  return builder.getI64TensorAttr(composedVec);
}

// Find all the broadcast intents and their dependencies. Start analyzing from
// the root and collect all broadcast intents that can help broadcast
// propagation from there.
void findBroadcastIntents(
    DynamicBroadcastInDimOp root, Block *parentBlock,
    BroadcastIntent &rootBcastIntent,
    SmallVector<BroadcastIntent> &bcastIntents,
    DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
        &bcastIntentDependencies) {
  OpBuilder builder(root.getContext());

  // Use the result vector of broadcast intents as a worklist. The set of
  // broadcast intents helps to ensure their uniqueness.
  DenseSet<BroadcastIntent> bcastIntentsSet;
  auto addToWorklistIfNew = [&](BroadcastIntent bcastIntent) {
    if (!bcastIntentsSet.count(bcastIntent)) {
      bcastIntentsSet.insert(bcastIntent);
      bcastIntents.push_back(bcastIntent);
    }
  };

  // Derive the broadcast intent associated with the root broadcast operation.
  // Add it to the worklist to seed the analysis.
  rootBcastIntent = {root.getResult().getType().cast<RankedTensorType>(),
                     root.operand(), root.output_dimensions(),
                     root.broadcast_dimensions()};
  addToWorklistIfNew(rootBcastIntent);

  // We use the result vector of broadcast intents as a worklist, the first `i`
  // intents of which have already been processed.
  for (int64_t i = 0; i < static_cast<int64_t>(bcastIntents.size()); ++i) {
    BroadcastIntent it = bcastIntents[i];
    Operation *producerOp = it.targetValue.getDefiningOp();

    // We can propagate broadcasts over (broadcasting) element-wise operations
    // and dynamic_broadcast_in_dim ops with the restriction that they must be
    // in the same block as they may depend on assuming regions.
    if (!producerOp || producerOp->getBlock() != parentBlock ||
        !allowsForBroadcastPropagation(producerOp)) {
      continue;
    }

    // We can skip broadcasting producers (dynamic_broadcast_in_dim ops) if we
    // compose their broadcasting dimensions.
    if (auto producerBcastOp =
            llvm::dyn_cast<DynamicBroadcastInDimOp>(producerOp)) {
      DenseIntElementsAttr composedBcastDims = composeBroadcastDimensionsAttr(
          builder, producerBcastOp.broadcast_dimensions(),
          it.broadcastDimensions.cast<DenseIntElementsAttr>());
      BroadcastIntent bcastedOperandIntent = {
          it.resultType, producerBcastOp.operand(), it.outputDimensions,
          composedBcastDims};

      // Record dependency and "recur".
      bcastIntentDependencies[it] = {bcastedOperandIntent};
      addToWorklistIfNew(bcastedOperandIntent);
      continue;
    }

    // We can propagate broadcasts over (broadcasting) element-wise operations.
    // Instead of broadcasting the result of such an op, we can broadcast the
    // operands and apply the element-wise operation to them.
    assert(allowsForElementwiseBroadcastPropagation(producerOp));
    bcastIntentDependencies[it] = {};
    for (auto operand : producerOp->getOperands()) {
      auto operandTy = operand.getType().cast<RankedTensorType>();
      auto operandBcastDims = operandTy.getRank() == 0
                                  ? builder.getI64TensorAttr({})
                                  : it.broadcastDimensions;
      auto bcastedOperandTy = RankedTensorType::get(it.resultType.getShape(),
                                                    operandTy.getElementType());
      BroadcastIntent bcastedOperandIntent = {
          bcastedOperandTy, operand, it.outputDimensions, operandBcastDims};

      // Record dependency and "recur".
      bcastIntentDependencies[it].push_back(bcastedOperandIntent);
      addToWorklistIfNew(bcastedOperandIntent);
    }
  }
}

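// Sorts the broadcast intents so that they can be realized front to back, with
// every dependency realized before its uses.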
void sortBroadcastIntentsInReverseTopologicalOrder(
    SmallVector<BroadcastIntent> &bcastIntentsVec, Block *parentBlock) {
  // Sort broadcast intents in reverse topological order of the producer ops.
  // We can use the positions in the block for this. All broadcast intents
  // outside the block (e.g. arguments) will be sorted towards the front.
  // This ordering is independent of the output dimensions as dependencies can
  // only occur between broadcast intents of the same output dimension.
  std::sort(bcastIntentsVec.begin(), bcastIntentsVec.end(),
            [parentBlock](const BroadcastIntent &a, const BroadcastIntent &b) {
              Operation *producerOpA = a.targetValue.getDefiningOp();
              Operation *producerOpB = b.targetValue.getDefiningOp();
              bool aInBlock = producerOpA != nullptr &&
                              producerOpA->getBlock() == parentBlock;
              bool bInBlock = producerOpB != nullptr &&
                              producerOpB->getBlock() == parentBlock;
              if (aInBlock && bInBlock) {
                return producerOpA->isBeforeInBlock(producerOpB);
              }
              return !aInBlock && bInBlock;
            });
}

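// Sets the rewriter's insertion point to the earliest position in `block` at
// which all of `values` are available, i.e. right after the last definition
// among them, or at the start of the block if none is defined in it.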
void setInsertionPointToEarliestPointWithAllValuesAvailable(
    PatternRewriter &rewriter, Block *block, ValueRange values) {
  Operation *lastDef = nullptr;
  for (Value v : values) {
    Operation *def = v.getDefiningOp();
    if (def && def->getBlock() == block) {
      if (!lastDef || lastDef->isBeforeInBlock(def)) lastDef = def;
    }
  }
  if (lastDef) {
    rewriter.setInsertionPointAfter(lastDef);
  } else {
    rewriter.setInsertionPointToStart(block);
  }
}

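// Materializes all broadcast intents and returns a mapping from each intent to
// the value that realizes it.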
DenseMap<BroadcastIntent, Value> realizeBroadcastIntents(
    SmallVector<BroadcastIntent> &sortedBcastIntents,
    DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
        &bcastIntentDependencies,
    Block *parentBlock, PatternRewriter &rewriter) {
  // Realize broadcast intents in order. They must be sorted so that their
  // dependencies are realized before them.
  DenseMap<BroadcastIntent, Value> realizations;
  for (auto it : sortedBcastIntents) {
    Operation *producerOp = it.targetValue.getDefiningOp();
    assert(!realizations.count(it) && "expect unrealized broadcast intent");
    auto deps = bcastIntentDependencies.find(it);

    // If we cannot propagate broadcasts further, materialize them as a
    // dynamic_broadcast_in_dim op.
    if (!producerOp || producerOp->getBlock() != parentBlock ||
        !allowsForBroadcastPropagation(producerOp)) {
      assert(deps == bcastIntentDependencies.end() && "expect no dependencies");
      setInsertionPointToEarliestPointWithAllValuesAvailable(
          rewriter, parentBlock,
          ValueRange{it.targetValue, it.outputDimensions});
      realizations[it] = rewriter.create<DynamicBroadcastInDimOp>(
          it.targetValue.getLoc(), it.resultType, it.targetValue,
          it.outputDimensions,
          it.broadcastDimensions.cast<DenseIntElementsAttr>());
      continue;
    }

    // For broadcast propagation across dynamic_broadcast_in_dim ops, the
    // broadcasted value is already materialized. Forward it.
    if (auto producerBcastOp =
            llvm::dyn_cast_or_null<DynamicBroadcastInDimOp>(producerOp)) {
      assert(deps != bcastIntentDependencies.end() &&
             deps->second.size() == 1 && "expect one dependency");
      auto bcastedOperand = realizations.find(deps->second.front());
      assert(bcastedOperand != realizations.end());
      realizations[it] = Value(bcastedOperand->second);
      continue;
    }

    // Otherwise, realize the broadcast intent for a (broadcasting)
    // element-wise operation based on the broadcasted operands.
    assert(allowsForElementwiseBroadcastPropagation(producerOp) &&
           "expect broadcast propagation over a (broadcasting) element-wise "
           "operation");
    assert(deps != bcastIntentDependencies.end() &&
           deps->second.size() == producerOp->getNumOperands() &&
           "expect one dependency per operand");
    auto bcastedOperands = llvm::to_vector(
        llvm::map_range(deps->second, [&](BroadcastIntent operandIntent) {
          auto bcastedOperand = realizations.find(operandIntent);
          assert(bcastedOperand != realizations.end() &&
                 "expect dependencies to be realized earlier");
          return bcastedOperand->second;
        }));
    setInsertionPointToEarliestPointWithAllValuesAvailable(
        rewriter, parentBlock, bcastedOperands);
    OperationState newProducerOpState(
        producerOp->getLoc(), producerOp->getName().getStringRef(),
        bcastedOperands, it.resultType, producerOp->getAttrs());
    Operation *newProducerOp = rewriter.create(newProducerOpState);
    assert(newProducerOp->getNumResults() == 1 && "expect exactly one result");
    realizations[it] = newProducerOp->getResults().front();
  }

  return realizations;
}

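// Starting from `root`, transitively erases all ops that are free of side
// effects and whose only users are other erased ops, walking through their
// operand-defining ops.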
void transitivelyEraseUnusedSideEffectFreeOps(Operation *root,
                                              PatternRewriter &rewriter) {
  // Find ops to erase.
  SmallPtrSet<Operation *, 16> opsToEraseSet;
  SmallVector<Operation *, 16> opsToErase;
  SmallVector<Operation *, 16> worklist = {root};
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();

    // Erase ops only once.
    if (opsToEraseSet.count(op)) continue;

    // Erase only operations that are unused and free of side effects.
    if (!MemoryEffectOpInterface::hasNoEffect(op) ||
        !llvm::all_of(op->getUsers(), [&opsToEraseSet](Operation *user) {
          return opsToEraseSet.count(user);
        })) {
      continue;
    }

    // Erase and "recur".
    opsToEraseSet.insert(op);
    opsToErase.push_back(op);
    for (Value operand : op->getOperands()) {
      if (Operation *def = operand.getDefiningOp()) worklist.push_back(def);
    }
  }

  // Finally, erase the ops in the order of their uses.
  for (Operation *op : opsToErase) rewriter.eraseOp(op);
}


LogicalResult propagateBroadcast(DynamicBroadcastInDimOp root,
                                 Block *parentBlock,
                                 PatternRewriter &rewriter) {
  // We can move broadcasts up over (i) (broadcasting) element-wise operations
  // and (ii) dynamic_broadcast_in_dim ops. This way, we propagate them through
  // the IR to perform them early. Instead of broadcasting the result of such
  // an op, we can broadcast the operands and apply the element-wise operation
  // to them.
  //
  // To avoid exponential growth of the IR, we will do this in two phases:
  //   1) First, we collect all the unique broadcast intents. These are
  //      broadcasted versions of values that we are interested in. They may
  //      later be materialized as an explicit broadcast or they can be the
  //      direct result of an operation over which a broadcast was propagated.
  //   2) Then, we fulfill every broadcast intent in reverse topological order
  //      to ensure that their dependencies (the broadcasted operands) are
  //      available.

  // Find the unique broadcast intents.
  BroadcastIntent rootBcastIntent;
  SmallVector<BroadcastIntent> bcastIntents;
  DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
      bcastIntentDependencies;
  findBroadcastIntents(root, parentBlock, rootBcastIntent, bcastIntents,
                       bcastIntentDependencies);

  // Fail if there is nothing but the root intent, i.e. if there is nothing to
  // rewrite here.
  if (bcastIntents.size() <= 1) {
    assert(bcastIntents.front() == rootBcastIntent && "expect root intent");
    return failure();
  }

  // Sort the broadcast intents in reverse topological order so that they can
  // be materialized and every dependency is available when needed.
  sortBroadcastIntentsInReverseTopologicalOrder(bcastIntents, parentBlock);

  // Realize broadcast intents.
  DenseMap<BroadcastIntent, Value> realizations = realizeBroadcastIntents(
      bcastIntents, bcastIntentDependencies, parentBlock, rewriter);

  // Find the operations that may become redundant after replacing the root
  // operation. This allows us to transitively erase unused side effect-free
  // operations that result from this rewrite (after the root operation is no
  // longer accessible).
  SmallVector<Operation *> possiblyUnused;
  for (auto operand : root->getOperands()) {
    if (Operation *def = operand.getDefiningOp()) possiblyUnused.push_back(def);
  }

  // Replace the root operation with its broadcast intent's realization.
  rewriter.replaceOp(root, realizations[rootBcastIntent]);

  // Erase all the operations that have become redundant as a result of this
  // rewrite.
  for (Operation *op : possiblyUnused) {
    transitivelyEraseUnusedSideEffectFreeOps(op, rewriter);
  }

  return success();
}

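// Rewrite pattern that propagates a dynamic_broadcast_in_dim op upwards within
// its block.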
struct BroadcastPropagationPattern
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                PatternRewriter &rewriter) const override {
    return propagateBroadcast(op, op->getBlock(), rewriter);
  }
};

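// Function-level pass that greedily applies the broadcast propagation pattern.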
struct BroadcastPropagationPass
    : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Collect patterns.
    RewritePatternSet patterns(ctx);
    patterns.add<BroadcastPropagationPattern>(ctx);

    // Apply broadcast propagation in reverse order to start propagation at
    // the root of broadcast chains. This avoids duplicate work.
    GreedyRewriteConfig config;
    config.useTopDownTraversal = false;

    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            config))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createBroadcastPropagationPass() {
  return std::make_unique<BroadcastPropagationPass>();
}

}  // namespace mhlo
}  // namespace mlir