1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14
15 ==============================================================================*/
16
17 #include <algorithm>
18 #include <utility>
19
20 #include "llvm/ADT/DenseMapInfo.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"
30 #include "mlir/IR/Attributes.h"
31 #include "mlir/IR/Location.h"
32 #include "mlir/IR/MLIRContext.h"
33 #include "mlir/IR/Operation.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37
38 namespace mlir {
39 namespace mhlo {
40 namespace {
41
42 // To avoid duplicate broadcasts, we collect all the intended broadcasts ahead
43 // of realizing any broadcasts in the IR. These are broadcasted versions of
44 // values that we are interested in, and they are uniquely characterized by a
45 // `BroadcastIntent` value.
46 struct BroadcastIntent {
47 RankedTensorType resultType;
48 Value targetValue;
49 Value outputDimensions;
50 Attribute broadcastDimensions;
operator ==mlir::mhlo::__anond615a9400111::BroadcastIntent51 bool operator==(BroadcastIntent rhs) const {
52 return resultType == rhs.resultType && targetValue == rhs.targetValue &&
53 outputDimensions == rhs.outputDimensions &&
54 broadcastDimensions == rhs.broadcastDimensions;
55 }
operator !=mlir::mhlo::__anond615a9400111::BroadcastIntent56 bool operator!=(BroadcastIntent rhs) const { return !(*this == rhs); }
57 };
58
59 } // namespace
60 } // namespace mhlo
61 } // namespace mlir
62
63 namespace llvm {
64
65 template <>
66 struct DenseMapInfo<mlir::mhlo::BroadcastIntent> {
getEmptyKeyllvm::DenseMapInfo67 static mlir::mhlo::BroadcastIntent getEmptyKey() {
68 return {DenseMapInfo<mlir::RankedTensorType>::getEmptyKey(),
69 DenseMapInfo<mlir::Value>::getEmptyKey(),
70 DenseMapInfo<mlir::Value>::getEmptyKey(),
71 DenseMapInfo<mlir::Attribute>::getEmptyKey()};
72 }
getTombstoneKeyllvm::DenseMapInfo73 static mlir::mhlo::BroadcastIntent getTombstoneKey() {
74 return {DenseMapInfo<mlir::RankedTensorType>::getTombstoneKey(),
75 DenseMapInfo<mlir::Value>::getTombstoneKey(),
76 DenseMapInfo<mlir::Value>::getTombstoneKey(),
77 DenseMapInfo<mlir::Attribute>::getTombstoneKey()};
78 }
getHashValuellvm::DenseMapInfo79 static unsigned getHashValue(const mlir::mhlo::BroadcastIntent &intent) {
80 return hash_combine(
81 DenseMapInfo<mlir::RankedTensorType>::getHashValue(intent.resultType),
82 DenseMapInfo<mlir::Value>::getHashValue(intent.targetValue),
83 DenseMapInfo<mlir::Value>::getHashValue(intent.outputDimensions),
84 DenseMapInfo<mlir::Attribute>::getHashValue(
85 intent.broadcastDimensions));
86 }
isEqualllvm::DenseMapInfo87 static bool isEqual(const mlir::mhlo::BroadcastIntent &lhs,
88 const mlir::mhlo::BroadcastIntent &rhs) {
89 return lhs == rhs;
90 }
91 };
92
93 } // namespace llvm
94
95 namespace mlir {
96 namespace mhlo {
97 namespace {
98
allowsForElementwiseBroadcastPropagation(Operation * op)99 bool allowsForElementwiseBroadcastPropagation(Operation *op) {
100 if (op && op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() &&
101 op->hasTrait<mlir::OpTrait::Elementwise>() && op->getNumResults() == 1) {
102 return true;
103 }
104 if (op && op->hasTrait<mlir::mhlo::OpTrait::BroadcastingElementwise>() &&
105 op->getNumResults() == 1) {
106 return true;
107 }
108 return false;
109 }
110
allowsForBroadcastPropagation(Operation * op)111 bool allowsForBroadcastPropagation(Operation *op) {
112 return llvm::isa_and_nonnull<DynamicBroadcastInDimOp>(op) ||
113 allowsForElementwiseBroadcastPropagation(op);
114 }
115
composeBroadcastDimensionsAttr(OpBuilder & builder,DenseIntElementsAttr a,DenseIntElementsAttr b)116 DenseIntElementsAttr composeBroadcastDimensionsAttr(OpBuilder &builder,
117 DenseIntElementsAttr a,
118 DenseIntElementsAttr b) {
119 SmallVector<int64_t> bVec =
120 llvm::to_vector(llvm::map_range(b, [](const APInt &it) {
121 return static_cast<int64_t>(it.getLimitedValue());
122 }));
123 SmallVector<int64_t> composedVec = llvm::to_vector(llvm::map_range(
124 a, [bVec](const APInt &it) { return bVec[it.getLimitedValue()]; }));
125 return builder.getI64TensorAttr(composedVec);
126 }
127
128 // Find all the broadcast intents and their dependencies. Start analyzing from
129 // the root an collect all broadcast intents that can help broadcast propagation
130 // from there.
findBroadcastIntents(DynamicBroadcastInDimOp root,Block * parentBlock,BroadcastIntent & rootBcastIntent,SmallVector<BroadcastIntent> & bcastIntents,DenseMap<BroadcastIntent,SmallVector<BroadcastIntent>> & bcastIntentDependencies)131 void findBroadcastIntents(
132 DynamicBroadcastInDimOp root, Block *parentBlock,
133 BroadcastIntent &rootBcastIntent,
134 SmallVector<BroadcastIntent> &bcastIntents,
135 DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
136 &bcastIntentDependencies) {
137 OpBuilder builder(root.getContext());
138
139 // Use the result vector of broadcast intents as a worklist. The set of
140 // broadcast intents helps to ensure their uniqueness.
141 DenseSet<BroadcastIntent> bcastIntentsSet;
142 auto addToWorklistIfNew = [&](BroadcastIntent bcastIntent) {
143 if (!bcastIntentsSet.count(bcastIntent)) {
144 bcastIntentsSet.insert(bcastIntent);
145 bcastIntents.push_back(bcastIntent);
146 }
147 };
148
149 // Derive the broadcast intent associated with the root broadcast operation.
150 // Add it to the worklist to seed the analysis.
151 rootBcastIntent = {root.getResult().getType().cast<RankedTensorType>(),
152 root.operand(), root.output_dimensions(),
153 root.broadcast_dimensions()};
154 addToWorklistIfNew(rootBcastIntent);
155
156 // We use result vector of broadcast intents as a worklist, the first `i`
157 // intents of which have been processed.
158 for (int64_t i = 0; i < static_cast<int64_t>(bcastIntents.size()); ++i) {
159 BroadcastIntent it = bcastIntents[i];
160 Operation *producerOp = it.targetValue.getDefiningOp();
161
162 // We can propagate broadcasts over (broadcasting) element-wise operations
163 // and dynamic_broadcast_in_dim ops with the restriction that they must be
164 // in the same block as they may depend on assuming regions.
165 if (!producerOp || producerOp->getBlock() != parentBlock ||
166 !allowsForBroadcastPropagation(producerOp)) {
167 continue;
168 }
169
170 // We can skip broadcasting producers (dynamic_broadcast_in_dim ops) if we
171 // compose their broadcasting dimensions.
172 if (auto producerBcastOp =
173 llvm::dyn_cast<DynamicBroadcastInDimOp>(producerOp)) {
174 DenseIntElementsAttr composedBcastDims = composeBroadcastDimensionsAttr(
175 builder, producerBcastOp.broadcast_dimensions(),
176 it.broadcastDimensions.cast<DenseIntElementsAttr>());
177 BroadcastIntent bcastedOperandIntent = {
178 it.resultType, producerBcastOp.operand(), it.outputDimensions,
179 composedBcastDims};
180
181 // Record dependency and "recur".
182 bcastIntentDependencies[it] = {bcastedOperandIntent};
183 addToWorklistIfNew(bcastedOperandIntent);
184 continue;
185 }
186
187 // We can propagate broadcasts over (broadcasting) element-wise operations.
188 // Instead of broadcasting the result of such an op, we can broadcast the
189 // operands and apply the element-wise operation to them.
190 assert(allowsForElementwiseBroadcastPropagation(producerOp));
191 bcastIntentDependencies[it] = {};
192 for (auto operand : producerOp->getOperands()) {
193 auto operandTy = operand.getType().cast<RankedTensorType>();
194 auto operandBcastDims = operandTy.getRank() == 0
195 ? builder.getI64TensorAttr({})
196 : it.broadcastDimensions;
197 auto bcastedOperandTy = RankedTensorType::get(it.resultType.getShape(),
198 operandTy.getElementType());
199 BroadcastIntent bcastedOperandIntent = {
200 bcastedOperandTy, operand, it.outputDimensions, operandBcastDims};
201
202 // Record dependency and "recur".
203 bcastIntentDependencies[it].push_back(bcastedOperandIntent);
204 addToWorklistIfNew(bcastedOperandIntent);
205 }
206 }
207 }
208
sortBroadcastIntentsInReverseTopologicalOrder(SmallVector<BroadcastIntent> & bcastIntentsVec,Block * parentBlock)209 void sortBroadcastIntentsInReverseTopologicalOrder(
210 SmallVector<BroadcastIntent> &bcastIntentsVec, Block *parentBlock) {
211 // Sort broadcast intents in reverse topological order of the producer ops. We
212 // can use the positions in the block for this. All broadcast intents outside
213 // the block (e.g. arguments) will be sorted towards the front.
214 // This ordering is independent of the output dimensions as dependencies can
215 // only occur between broadcast intents of the same output dimension.
216 std::sort(bcastIntentsVec.begin(), bcastIntentsVec.end(),
217 [parentBlock](const BroadcastIntent &a, const BroadcastIntent &b) {
218 Operation *producerOpA = a.targetValue.getDefiningOp();
219 Operation *producerOpB = b.targetValue.getDefiningOp();
220 bool aInBlock = producerOpA != nullptr &&
221 producerOpA->getBlock() == parentBlock;
222 bool bInBlock = producerOpB != nullptr &&
223 producerOpB->getBlock() == parentBlock;
224 if (aInBlock && bInBlock) {
225 return producerOpA->isBeforeInBlock(producerOpB);
226 }
227 return !aInBlock && bInBlock;
228 });
229 }
230
setInsertionPointToEarliestPointWithAllValuesAvailable(PatternRewriter & rewriter,Block * block,ValueRange values)231 void setInsertionPointToEarliestPointWithAllValuesAvailable(
232 PatternRewriter &rewriter, Block *block, ValueRange values) {
233 Operation *lastDef = nullptr;
234 for (Value v : values) {
235 Operation *def = v.getDefiningOp();
236 if (def && def->getBlock() == block) {
237 if (!lastDef || lastDef->isBeforeInBlock(def)) lastDef = def;
238 }
239 }
240 if (lastDef) {
241 rewriter.setInsertionPointAfter(lastDef);
242 } else {
243 rewriter.setInsertionPointToStart(block);
244 }
245 }
246
realizeBroadcastIntents(SmallVector<BroadcastIntent> & sortedBcastIntents,DenseMap<BroadcastIntent,SmallVector<BroadcastIntent>> & bcastIntentDependencies,Block * parentBlock,PatternRewriter & rewriter)247 DenseMap<BroadcastIntent, Value> realizeBroadcastIntents(
248 SmallVector<BroadcastIntent> &sortedBcastIntents,
249 DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
250 &bcastIntentDependencies,
251 Block *parentBlock, PatternRewriter &rewriter) {
252 // Realize broadcast intents in order. They must be sorted so that their
253 // dependencies are realized before them.
254 DenseMap<BroadcastIntent, Value> realizations;
255 for (auto it : sortedBcastIntents) {
256 Operation *producerOp = it.targetValue.getDefiningOp();
257 assert(!realizations.count(it) && "expect unrealized broadcast intent");
258 auto deps = bcastIntentDependencies.find(it);
259
260 // If we cannot propagate broadcasts further, materialize them as a
261 // dynamic_broadcast_in_dim op.
262 if (!producerOp || producerOp->getBlock() != parentBlock ||
263 !allowsForBroadcastPropagation(producerOp)) {
264 assert(deps == bcastIntentDependencies.end() && "expect no dependencies");
265 setInsertionPointToEarliestPointWithAllValuesAvailable(
266 rewriter, parentBlock,
267 ValueRange{it.targetValue, it.outputDimensions});
268 realizations[it] = rewriter.create<DynamicBroadcastInDimOp>(
269 it.targetValue.getLoc(), it.resultType, it.targetValue,
270 it.outputDimensions,
271 it.broadcastDimensions.cast<DenseIntElementsAttr>());
272 continue;
273 }
274
275 // For broadcast propagation across dynamic_broadcast_in_dim ops, the
276 // broadcasted value is already materialized. Forward it.
277 if (auto producerBcastOp =
278 llvm::dyn_cast_or_null<DynamicBroadcastInDimOp>(producerOp)) {
279 assert(deps != bcastIntentDependencies.end() &&
280 deps->second.size() == 1 && "expect one dependency");
281 auto bcastedOperand = realizations.find(deps->second.front());
282 assert(bcastedOperand != realizations.end());
283 realizations[it] = Value(bcastedOperand->second);
284 continue;
285 }
286
287 // Othwerwise, realize broadcast intent for a (broadcasting) element-wise
288 // operation based on the broadcasted operands.
289 assert(allowsForElementwiseBroadcastPropagation(producerOp) &&
290 "expect broadcast propagation over an (broadcasting) element-wise "
291 "operation");
292 assert(deps != bcastIntentDependencies.end() &&
293 deps->second.size() == producerOp->getNumOperands() &&
294 "expect one dependency per operand");
295 auto bcastedOperands = llvm::to_vector(
296 llvm::map_range(deps->second, [&](BroadcastIntent operandIntent) {
297 auto bcastedOperand = realizations.find(operandIntent);
298 assert(bcastedOperand != realizations.end() &&
299 "expect dependencies to be realized earlier");
300 return bcastedOperand->second;
301 }));
302 setInsertionPointToEarliestPointWithAllValuesAvailable(
303 rewriter, parentBlock, bcastedOperands);
304 OperationState newProducerOpState(
305 producerOp->getLoc(), producerOp->getName().getStringRef(),
306 bcastedOperands, it.resultType, producerOp->getAttrs());
307 Operation *newProducerOp = rewriter.create(newProducerOpState);
308 assert(newProducerOp->getNumResults() == 1 && "expect exactly one result");
309 realizations[it] = newProducerOp->getResults().front();
310 }
311
312 return realizations;
313 }
314
transitivelyEraseUnusedSideEffectFreeOps(Operation * root,PatternRewriter & rewriter)315 void transitivelyEraseUnusedSideEffectFreeOps(Operation *root,
316 PatternRewriter &rewriter) {
317 // Find ops to erase.
318 SmallPtrSet<Operation *, 16> opsToEraseSet;
319 SmallVector<Operation *, 16> opsToErase;
320 SmallVector<Operation *, 16> worklist = {root};
321 while (!worklist.empty()) {
322 Operation *op = worklist.pop_back_val();
323
324 // Erase ops only once.
325 if (opsToEraseSet.count(op)) continue;
326
327 // Erase only operations that are unused and free of side effects.
328 if (!MemoryEffectOpInterface::hasNoEffect(op) ||
329 !llvm::all_of(op->getUsers(), [opsToEraseSet](Operation *user) {
330 return opsToEraseSet.count(user);
331 })) {
332 continue;
333 }
334
335 // Erase and "recur".
336 opsToEraseSet.insert(op);
337 opsToErase.push_back(op);
338 for (Value operand : op->getOperands()) {
339 if (Operation *def = operand.getDefiningOp()) worklist.push_back(def);
340 }
341 }
342
343 // Finally, erase the ops in the order of their uses.
344 for (Operation *op : opsToErase) rewriter.eraseOp(op);
345 }
346
propagateBroadcast(DynamicBroadcastInDimOp root,Block * parentBlock,PatternRewriter & rewriter)347 LogicalResult propagateBroadcast(DynamicBroadcastInDimOp root,
348 Block *parentBlock,
349 PatternRewriter &rewriter) {
350 // We can move broadcasts up over (i) (broadcasting) element-wise operations
351 // and (i) dynamic_broadcast_in_dim ops. This way, we propagate them through
352 // the IR to perform them early. Instead of broadcasting the result of such an
353 // op, we can broadcast the operands and apply the element-wise operation to
354 // them.
355 //
356 // To avoid exponential growth of the IR, we will do this in two phases:
357 // 1) First, we collect all the unique broadcast intents. These are
358 // broadcasted versions of values that we are interested in. They may
359 // later be materialized as an explicit broadcast or they can be the
360 // direct result of an operation over which a broadcast was propagated.
361 // 2) Then, we fulfill every broadcast intent in reverse topological order
362 // to ensure that their dependencies (the broadcasted operands) are
363 // available.
364
365 // Find the unique broadcast intents.
366 BroadcastIntent rootBcastIntent;
367 SmallVector<BroadcastIntent> bcastIntents;
368 DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
369 bcastIntentDependencies;
370 findBroadcastIntents(root, parentBlock, rootBcastIntent, bcastIntents,
371 bcastIntentDependencies);
372
373 // Fail if there is nothing but the root intent, i.e. if there is nothing to
374 // rewrite here.
375 if (bcastIntents.size() <= 1) {
376 assert(bcastIntents.front() == rootBcastIntent && "expect root intent");
377 return failure();
378 }
379
380 // Sort the broadcast intents in reverse topological order so that they can be
381 // materialized and every depency is available when needed.
382 sortBroadcastIntentsInReverseTopologicalOrder(bcastIntents, parentBlock);
383
384 // Realize broadcast intents.
385 DenseMap<BroadcastIntent, Value> realizations = realizeBroadcastIntents(
386 bcastIntents, bcastIntentDependencies, parentBlock, rewriter);
387
388 // Find the operations that may become redundant after replacing the root
389 // operation. This allows us to transitively erase unused side effect-free
390 // operations that result from this rewrite (after the root operation is no
391 // longer accessible).
392 SmallVector<Operation *> possiblyUnused;
393 for (auto operand : root->getOperands()) {
394 if (Operation *def = operand.getDefiningOp()) possiblyUnused.push_back(def);
395 }
396
397 // Replace the root operation with its broadcast intent's realization.
398 rewriter.replaceOp(root, realizations[rootBcastIntent]);
399
400 // Erase all the operations that have become redundant as a result of this
401 // rewrite.
402 for (Operation *op : possiblyUnused) {
403 transitivelyEraseUnusedSideEffectFreeOps(op, rewriter);
404 }
405
406 return success();
407 }
408
409 struct BroadcastPropagationPattern
410 : public OpRewritePattern<DynamicBroadcastInDimOp> {
411 using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
412
matchAndRewritemlir::mhlo::__anond615a9400211::BroadcastPropagationPattern413 LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
414 PatternRewriter &rewriter) const override {
415 return propagateBroadcast(op, op->getBlock(), rewriter);
416 }
417 };
418
419 struct BroadcastPropagationPass
420 : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
getDependentDialectsmlir::mhlo::__anond615a9400211::BroadcastPropagationPass421 void getDependentDialects(DialectRegistry ®istry) const override {
422 registry.insert<mhlo::MhloDialect>();
423 }
424
runOnOperationmlir::mhlo::__anond615a9400211::BroadcastPropagationPass425 void runOnOperation() override {
426 MLIRContext *ctx = &getContext();
427
428 // Collect patterns.
429 RewritePatternSet patterns(ctx);
430 patterns.add<BroadcastPropagationPattern>(ctx);
431
432 // Apply broadcast propagation in reverse order to start propagation at
433 // the root of broadcast chains. This avoids duplicate work.
434 GreedyRewriteConfig config;
435 config.useTopDownTraversal = false;
436
437 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
438 config))) {
439 return signalPassFailure();
440 }
441 }
442 };
443
444 } // namespace
445
createBroadcastPropagationPass()446 std::unique_ptr<OperationPass<func::FuncOp>> createBroadcastPropagationPass() {
447 return std::make_unique<BroadcastPropagationPass>();
448 }
449
450 } // namespace mhlo
451 } // namespace mlir
452