/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/
#include <algorithm>
#include <functional>
#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace mhlo {
namespace {

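// Whether a broadcastability constraint operand refers to a shape value
// directly or to the shape of a tensor value.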
enum class CstrBroadcastableOperandKind {
  kValue = 0,
  kShapeOfValue = 1,
};

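// An operand of a broadcastability constraint, expressed in terms of a block
// argument: either the argument itself (a shape value) or the shape of the
// argument (a tensor value).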
struct CstrBroadcastableOperand {
  static CstrBroadcastableOperand valueOf(BlockArgument barg) {
    return {CstrBroadcastableOperandKind::kValue, barg};
  }
  static CstrBroadcastableOperand shapeOf(BlockArgument barg) {
    return {CstrBroadcastableOperandKind::kShapeOfValue, barg};
  }

  // An arbitrary but well-defined order.
  inline bool operator<(const CstrBroadcastableOperand &rhs) const {
    if (kind != rhs.kind) return kind < rhs.kind;
    return value.getArgNumber() < rhs.value.getArgNumber();
  }
  inline bool operator>(const CstrBroadcastableOperand &rhs) const {
    return rhs < *this;
  }
  inline bool operator<=(const CstrBroadcastableOperand &rhs) const {
    return !(*this > rhs);
  }
  inline bool operator>=(const CstrBroadcastableOperand &rhs) const {
    return !(*this < rhs);
  }

  // Equality.
  inline bool operator==(const CstrBroadcastableOperand &rhs) const {
    return kind == rhs.kind && value == rhs.value;
  }
  inline bool operator!=(const CstrBroadcastableOperand &rhs) const {
    return !(*this == rhs);
  }

  CstrBroadcastableOperandKind kind;
  BlockArgument value;
};

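// The intent to create a `shape.cstr_broadcastable` op for a set of operands
// at a given location. Intents compare equal iff their operand sets are equal.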
struct CstrBroadcastableIntent {
  explicit CstrBroadcastableIntent(Location loc) : loc(loc) {}

  // A well-defined order that sorts weaker constraints to the front.
  inline bool operator<(const CstrBroadcastableIntent &rhs) const {
    // Sort weaker constraints to the front.
    if (operands.size() != rhs.operands.size())
      return operands.size() < rhs.operands.size();

    return operands < rhs.operands;
  }
  inline bool operator>(const CstrBroadcastableIntent &rhs) const {
    return rhs < *this;
  }
  inline bool operator<=(const CstrBroadcastableIntent &rhs) const {
    return !(*this > rhs);
  }
  inline bool operator>=(const CstrBroadcastableIntent &rhs) const {
    return !(*this < rhs);
  }

  inline bool operator==(const CstrBroadcastableIntent &rhs) const {
    return operands == rhs.operands;
  }
  inline bool operator!=(const CstrBroadcastableIntent &rhs) const {
    return !(*this == rhs);
  }

  Location loc;
  SmallVector<CstrBroadcastableOperand> operands;
};

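// Canonicalize broadcastability constraints: deduplicate the operands of each
// constraint and drop constraints that are implied by a stronger constraint,
// i.e. one whose operand set is a superset.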
void canonicalizeBroadcastabilityCstrs(
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Sort inner constraint arguments and eliminate duplicates.
  for (auto &it : broadcastabilityCstrs) {
    llvm::sort(it.operands);
    auto *newEnd =
        llvm::unique(it.operands, [](auto a, auto b) { return a == b; });
    it.operands.erase(newEnd, it.operands.end());
  }

  // Sort broadcastability constraints, strongest to the front.
  llvm::sort(broadcastabilityCstrs, std::greater<>());

  // Remove broadcastability constraints if they are implied by stronger
  // constraints.
  for (int64_t i = 0; i < static_cast<int64_t>(broadcastabilityCstrs.size());
       i++) {
    CstrBroadcastableIntent &strongCstr = broadcastabilityCstrs[i];
    auto *newEnd = std::remove_if(
        broadcastabilityCstrs.begin() + i + 1, broadcastabilityCstrs.end(),
        [strongCstr](CstrBroadcastableIntent weakerCstr) {
          assert(weakerCstr.operands.size() <= strongCstr.operands.size() &&
                 "only look at possibly weaker broadcastability constraints");
          return std::includes(
              strongCstr.operands.begin(), strongCstr.operands.end(),
              weakerCstr.operands.begin(), weakerCstr.operands.end());
        });
    broadcastabilityCstrs.erase(newEnd, broadcastabilityCstrs.end());
  }
}

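// Sort block arguments by argument number and remove duplicates.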
void eliminateDuplicateBlockArguments(SmallVector<BlockArgument> &bargs) {
  llvm::sort(bargs, [](auto a, auto b) {
    return a.getArgNumber() < b.getArgNumber();
  });
  auto *newEnd = llvm::unique(bargs, [](auto a, auto b) { return a == b; });
  bargs.erase(newEnd, bargs.end());
}

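// Inline all `shape.assuming` regions nested under the block: splice the body
// ops in front of the assuming op and forward the yielded values to the
// assuming op's uses.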
void inlineAssumingRegions(Block *theBlock) {
  theBlock->walk([](shape::AssumingOp aop) {
    Block *body = aop.getBody();
    auto yop = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
    aop->getBlock()->getOperations().splice(aop->getIterator(),
                                            body->getOperations());
    aop.replaceAllUsesWith(yop.getOperands());
    yop.erase();
    aop.erase();
  });
}

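// Materialize the fused constraint witness: carry over the argument witnesses,
// recreate one `shape.cstr_broadcastable` op per intent (materializing
// `shape.shape_of` at most once per value), and conjoin everything with
// `shape.assuming_all` if there is more than one witness.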
Value materializeFusedConstraints(
    Location loc, OpBuilder &builder, SmallVector<BlockArgument> &argumentCstrs,
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Materialize shape_of only once per value.
  DenseMap<Value, Value> shapeOfMaterializations;
  auto getShapeOfMaterialization = [&](Value arg) {
    auto it = shapeOfMaterializations.find(arg);
    if (it != shapeOfMaterializations.end()) return it->second;
    auto shapeOf = builder.create<shape::ShapeOfOp>(loc, arg).getResult();
    shapeOfMaterializations[arg] = shapeOf;
    return shapeOf;
  };

  SmallVector<Value> witnesses;
  witnesses.reserve(argumentCstrs.size() + broadcastabilityCstrs.size());

  // Carry over the argument witnesses.
  for (BlockArgument it : argumentCstrs) witnesses.push_back(it);

  // Materialize broadcastability constraints.
  for (const CstrBroadcastableIntent &it : broadcastabilityCstrs) {
    auto shapes = llvm::to_vector<8>(llvm::map_range(
        it.operands,
        [getShapeOfMaterialization](const CstrBroadcastableOperand &operand) {
          if (operand.kind == CstrBroadcastableOperandKind::kShapeOfValue) {
            return getShapeOfMaterialization(operand.value);
          }
          assert(operand.kind == CstrBroadcastableOperandKind::kValue);
          Value shape = operand.value;
          return shape;
        }));
    auto cstr = builder.create<shape::CstrBroadcastableOp>(it.loc, shapes);
    witnesses.push_back(cstr);
  }
  if (witnesses.size() == 1) return witnesses.front();
  return builder.create<shape::AssumingAllOp>(loc, witnesses);
}

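// Rewrite the block to use a single fused constraint: inline the old assuming
// regions, erase redundant ops, materialize the fused witness at the top of
// the block, and wrap the remaining ops (up to the terminator) in one new
// `shape.assuming` region.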
void materializeBlockGlobalConstraintFusion(
    Location loc, OpBuilder &builder, Block *theBlock,
    llvm::SmallSetVector<Operation *, 16> &toBeErased,
    SmallVector<BlockArgument> &argumentCstrs,
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Eliminate the old assuming regions and inline their ops into the main
  // function body.
  inlineAssumingRegions(theBlock);

  // Delete ops that are known to have become redundant by inlining of assuming
  // regions.
  for (auto *it : toBeErased) it->erase();

  // Materialize fused constraints at the beginning of the function.
  builder.setInsertionPointToStart(theBlock);
  Value fusedCstr = materializeFusedConstraints(loc, builder, argumentCstrs,
                                                broadcastabilityCstrs);

  // Create fused assuming region with empty body.
  Operation *theBlockTerminator = theBlock->getTerminator();
  auto fusedAop = builder.create<shape::AssumingOp>(
      loc, theBlockTerminator->getOperandTypes(), fusedCstr);
  auto *fusedAopBody = new Block;
  fusedAop.getDoRegion().getBlocks().push_back(fusedAopBody);

  // Splice all the original block's operations into the fused assuming
  // region's body (except for the block terminator).
  auto &dstBlocks = fusedAopBody->getOperations();
  dstBlocks.splice(dstBlocks.begin(), theBlock->getOperations(),
                   builder.getInsertionPoint(),
                   theBlockTerminator->getIterator());

  // Yield results from the assuming region and pass them on to the original
  // block terminator.
  builder.setInsertionPointToEnd(fusedAopBody);
  builder.create<shape::AssumingYieldOp>(loc,
                                         theBlockTerminator->getOperands());
  theBlockTerminator->setOperands(fusedAop.getResults());
}

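// Decides whether `use` remains a real use after constraint fusion. Uses by
// ops already considered dead do not count, and uses through
// `shape.assuming_yield` are traced to the corresponding outer results.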
bool isRemainingUse(OpOperand &use, Block *theBlock,
                    llvm::SmallSetVector<Operation *, 16> &considerDead) {
  Operation *op = use.getOwner();

  // Not a real use if the user is considered dead.
  if (considerDead.count(op)) return false;

  // Assuming regions in the regarded block are not real uses as they will be
  // inlined.
  if (auto aop = llvm::dyn_cast<shape::AssumingOp>(op))
    return aop->getBlock() == theBlock;

  // Look through assuming regions' yield ops.
  if (auto yop = llvm::dyn_cast<shape::AssumingYieldOp>(op)) {
    auto aop = yop->getParentOfType<shape::AssumingOp>();
    auto outerResult = aop.getResults()[use.getOperandNumber()];
    return llvm::all_of(outerResult.getUses(), [&](auto &outerUse) {
      return isRemainingUse(outerUse, theBlock, considerDead);
    });
  }

  // Otherwise, consider it a real use.
  return true;
}

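// Flag `op` for erasure if none of its uses remains a real use after fusion.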
void tryFlagForErase(Block *theBlock, Operation *op,
                     llvm::SmallSetVector<Operation *, 16> &toBeErased) {
  if (llvm::none_of(op->getUses(), [&](auto &use) {
        return isRemainingUse(use, theBlock, toBeErased);
      })) {
    toBeErased.insert(op);
  }
}

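// Returns true if `op` is nested (possibly transitively) within `theBlock`.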
bool isWithinBlock(Operation *op, Block *theBlock) {
  while (op != nullptr && op->getBlock() != theBlock) op = op->getParentOp();
  return op != nullptr;
}

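// Trace the shape operands of a `shape.cstr_broadcastable` op back to
// arguments of the regarded block, looking through `shape.shape_of`,
// `shape.broadcast`, and `shape.assuming` ops. Collects the resulting operands
// and flags ops that become redundant; fails if an operand cannot be expressed
// in terms of the block's arguments.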
LogicalResult analyzeBroadcastableConstraint(
    shape::CstrBroadcastableOp cstrBcastable, Block *theBlock,
    llvm::SmallSetVector<Operation *, 16> &toBeErased,
    SmallVector<CstrBroadcastableOperand> &transitiveBcastableCstrOperands) {
  SmallVector<Value> worklist = cstrBcastable.getShapes();
  while (!worklist.empty()) {
    Value shape = worklist.pop_back_val();
    Operation *def = shape.getDefiningOp();

    // For shapes without a definition, expect them to be an argument of the
    // regarded block.
    if (def == nullptr) {
      auto barg = shape.dyn_cast<BlockArgument>();
      if (!barg || barg.getParentBlock() != theBlock) return failure();
      transitiveBcastableCstrOperands.push_back(
          CstrBroadcastableOperand::valueOf(barg));
      continue;
    }

    // For shape_of ops, expect them to wrap an argument of the regarded block.
    // The shape reification pass, which should be run before this pass, helps
    // achieve this.
    if (auto sof = llvm::dyn_cast<shape::ShapeOfOp>(def)) {
      if (!isWithinBlock(sof, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      auto barg = sof.getArg().dyn_cast<BlockArgument>();
      if (!barg) return failure();
      transitiveBcastableCstrOperands.push_back(
          CstrBroadcastableOperand::shapeOf(barg));
      continue;
    }

    // For broadcast ops, broadcastability of the operands is an implicit
    // requirement. We can inline the operands.
    if (auto bcast = llvm::dyn_cast<shape::BroadcastOp>(def)) {
      if (!isWithinBlock(bcast, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      auto bcastShapes = bcast.getShapes();
      worklist.append(bcastShapes.begin(), bcastShapes.end());
      continue;
    }

    // Look into assuming ops to proceed.
    if (auto aop = llvm::dyn_cast<shape::AssumingOp>(def)) {
      if (!isWithinBlock(aop, theBlock)) return failure();
      auto yieldOp =
          llvm::cast<shape::AssumingYieldOp>(aop.getBody()->getTerminator());
      size_t i = llvm::find(aop.getResults(), shape).getIndex();
      Value innerShape = yieldOp.getOperand(i);
      worklist.push_back(innerShape);
      continue;
    }

    // Otherwise, bail.
    return failure();
  }

  return success();
}

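// Collect all block-global constraints reachable from the witnesses of the
// assuming regions in the block: plain argument witnesses and broadcastability
// constraints expressed in terms of block arguments. Fails if a constraint
// cannot be analyzed.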
LogicalResult analyzeBlockGlobalConstraints(
    Block *theBlock, llvm::SmallSetVector<Operation *, 16> &toBeErased,
    SmallVector<BlockArgument> &argumentCstrs,
    SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
  // Find all the assuming regions and start the search for reachable
  // constraints from there.
  SmallVector<Value> cstrWorklist;
  theBlock->walk(
      [&](shape::AssumingOp aop) { cstrWorklist.push_back(aop.getWitness()); });

  while (!cstrWorklist.empty()) {
    Value cstr = cstrWorklist.pop_back_val();
    Operation *def = cstr.getDefiningOp();

    // For witnesses without a definition, expect them to be an argument of the
    // regarded block.
    if (def == nullptr) {
      auto barg = cstr.dyn_cast<BlockArgument>();
      if (!barg || barg.getParentBlock() != theBlock) return failure();
      argumentCstrs.push_back(barg);
      continue;
    }

    // For conjunctions, continue with the operands.
    if (auto aaop = llvm::dyn_cast<shape::AssumingAllOp>(def)) {
      if (!isWithinBlock(aaop, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      auto aaopCstrs = aaop.getOperands();
      cstrWorklist.append(aaopCstrs.begin(), aaopCstrs.end());
      continue;
    }

    // For broadcastable constraints, find the transitively included shape
    // operands.
    if (auto cstrBcastable = llvm::dyn_cast<shape::CstrBroadcastableOp>(def)) {
      if (!isWithinBlock(cstrBcastable, theBlock)) return failure();
      tryFlagForErase(theBlock, def, toBeErased);
      CstrBroadcastableIntent bcastableIntent(cstrBcastable.getLoc());
      if (failed(analyzeBroadcastableConstraint(
              cstrBcastable, theBlock, toBeErased, bcastableIntent.operands))) {
        return failure();
      }
      broadcastabilityCstrs.push_back(bcastableIntent);
      continue;
    }

    // Look into assuming regions when running into them. They will be inlined
    // later.
    if (auto aop = llvm::dyn_cast<shape::AssumingOp>(def)) {
      if (!isWithinBlock(aop, theBlock)) return failure();
      size_t i = llvm::find(aop.getResults(), cstr).getIndex();
      auto yieldOp =
          llvm::cast<shape::AssumingYieldOp>(aop.getBody()->getTerminator());
      cstrWorklist.push_back(yieldOp.getOperand(i));
      continue;
    }

    // Otherwise, bail.
    return failure();
  }

  return success();
}

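// Analyze, simplify, and re-materialize all block-global constraints of
// `theBlock` as a single fused assuming region.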
LogicalResult fuseBlockGlobalConstraints(Location loc, OpBuilder &builder,
                                         Block *theBlock) {
  // Analyze block-global constraints.
  SmallVector<BlockArgument> argumentCstrs;
  SmallVector<CstrBroadcastableIntent> broadcastabilityCstrs;
  llvm::SmallSetVector<Operation *, 16> toBeErased;
  if (failed(analyzeBlockGlobalConstraints(theBlock, toBeErased, argumentCstrs,
                                           broadcastabilityCstrs))) {
    return failure();
  }

  // Return early if there is nothing to do.
  if (argumentCstrs.empty() && broadcastabilityCstrs.empty()) {
    return success();
  }

  // Simplify constraints.
  eliminateDuplicateBlockArguments(argumentCstrs);
  canonicalizeBroadcastabilityCstrs(broadcastabilityCstrs);

  // Materialize constraint fusion.
  materializeBlockGlobalConstraintFusion(loc, builder, theBlock, toBeErased,
                                         argumentCstrs, broadcastabilityCstrs);

  return success();
}

struct ConstraintFusionPass
    : public ConstraintFusionPassBase<ConstraintFusionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect>();
  }

  void runOnOperation() override {
    func::FuncOp f = getOperation();
    auto loc = f.getLoc();
    OpBuilder builder(&getContext());
    for (auto &block : f.getBody().getBlocks()) {
      if (failed(fuseBlockGlobalConstraints(loc, builder, &block)))
        return signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createConstraintFusionPass() {
  return std::make_unique<ConstraintFusionPass>();
}

}  // namespace mhlo
}  // namespace mlir