1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 
15 ==============================================================================*/
16 
17 #include <algorithm>
18 #include <functional>
19 #include <memory>
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/Shape/IR/Shape.h"
28 #include "mlir/Pass/Pass.h"
29 
30 namespace mlir {
31 namespace mhlo {
32 namespace {
33 
// Describes how a block argument takes part in a broadcastability constraint.
// The enumerator order doubles as the primary sort key of
// `CstrBroadcastableOperand`, so `kValue` operands sort before `kShapeOfValue`
// operands.
enum class CstrBroadcastableOperandKind {
  // The block argument itself is the constraint operand.
  kValue,
  // The constraint operand is the `shape.shape_of` of the block argument.
  kShapeOfValue,
};
38 
39 struct CstrBroadcastableOperand {
valueOfmlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand40   static CstrBroadcastableOperand valueOf(BlockArgument barg) {
41     return {CstrBroadcastableOperandKind::kValue, barg};
42   }
shapeOfmlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand43   static CstrBroadcastableOperand shapeOf(BlockArgument barg) {
44     return {CstrBroadcastableOperandKind::kShapeOfValue, barg};
45   }
46 
47   // An arbitrary but well define order.
operator <mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand48   inline bool operator<(const CstrBroadcastableOperand &rhs) const {
49     if (kind != rhs.kind) return kind < rhs.kind;
50     return value.getArgNumber() < rhs.value.getArgNumber();
51   }
operator >mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand52   inline bool operator>(const CstrBroadcastableOperand &rhs) const {
53     return rhs < *this;
54   }
operator <=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand55   inline bool operator<=(const CstrBroadcastableOperand &rhs) const {
56     return !(*this > rhs);
57   }
operator >=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand58   inline bool operator>=(const CstrBroadcastableOperand &rhs) const {
59     return !(*this < rhs);
60   }
61 
62   // Equality.
operator ==mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand63   inline bool operator==(const CstrBroadcastableOperand &rhs) const {
64     return kind == rhs.kind && value == rhs.value;
65   }
operator !=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableOperand66   inline bool operator!=(const CstrBroadcastableOperand &rhs) const {
67     return !(*this == rhs);
68   }
69 
70   CstrBroadcastableOperandKind kind;
71   BlockArgument value;
72 };
73 
74 struct CstrBroadcastableIntent {
CstrBroadcastableIntentmlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent75   explicit CstrBroadcastableIntent(Location loc) : loc(loc) {}
76 
77   // A well defined order that sorts weaker constraints to the front.
operator <mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent78   inline bool operator<(const CstrBroadcastableIntent &rhs) const {
79     // Sort weaker constraints to the front.
80     if (operands.size() != rhs.operands.size())
81       return operands.size() < rhs.operands.size();
82 
83     return operands < rhs.operands;
84   }
operator >mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent85   inline bool operator>(const CstrBroadcastableIntent &rhs) const {
86     return rhs < *this;
87   }
operator <=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent88   inline bool operator<=(const CstrBroadcastableIntent &rhs) const {
89     return !(*this > rhs);
90   }
operator >=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent91   inline bool operator>=(const CstrBroadcastableIntent &rhs) const {
92     return !(*this < rhs);
93   }
94 
operator ==mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent95   inline bool operator==(const CstrBroadcastableIntent &rhs) const {
96     return operands == rhs.operands;
97   }
operator !=mlir::mhlo::__anon08b4e7980111::CstrBroadcastableIntent98   inline bool operator!=(const CstrBroadcastableIntent &rhs) const {
99     return !(*this == rhs);
100   }
101 
102   Location loc;
103   SmallVector<CstrBroadcastableOperand> operands;
104 };
105 
canonicalizeBroadcastabilityCstrs(SmallVector<CstrBroadcastableIntent> & broadcastabilityCstrs)106 void canonicalizeBroadcastabilityCstrs(
107     SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
108   // Sort inner constraint arguments and eliminate duplicates.
109   for (auto &it : broadcastabilityCstrs) {
110     llvm::sort(it.operands);
111     auto *newEnd =
112         llvm::unique(it.operands, [](auto a, auto b) { return a == b; });
113     it.operands.erase(newEnd, it.operands.end());
114   }
115 
116   // Sort broadcastability constraints and sort the strongest to the front.
117   llvm::sort(broadcastabilityCstrs, std::greater<>());
118 
119   // Remove broadcastability constraints if they are implied by stronger
120   // constraints.
121   for (int64_t i = 0; i < static_cast<int64_t>(broadcastabilityCstrs.size());
122        i++) {
123     CstrBroadcastableIntent &strongCstr = broadcastabilityCstrs[i];
124     auto *newEnd = std::remove_if(
125         broadcastabilityCstrs.begin() + i + 1, broadcastabilityCstrs.end(),
126         [strongCstr](CstrBroadcastableIntent weakerCstr) {
127           assert(weakerCstr.operands.size() <= strongCstr.operands.size() &&
128                  "only look at possibly weaker broadcastability constraints");
129           return std::includes(
130               strongCstr.operands.begin(), strongCstr.operands.end(),
131               weakerCstr.operands.begin(), weakerCstr.operands.end());
132         });
133     broadcastabilityCstrs.erase(newEnd, broadcastabilityCstrs.end());
134   }
135 }
136 
eliminateDuplicateBlockArguments(SmallVector<BlockArgument> & bargs)137 void eliminateDuplicateBlockArguments(SmallVector<BlockArgument> &bargs) {
138   llvm::sort(bargs, [](auto a, auto b) {
139     return a.getArgNumber() < b.getArgNumber();
140   });
141   auto *newEnd = llvm::unique(bargs, [](auto a, auto b) { return a == b; });
142   bargs.erase(newEnd, bargs.end());
143 }
144 
inlineAssumingRegions(Block * theBlock)145 void inlineAssumingRegions(Block *theBlock) {
146   theBlock->walk([](shape::AssumingOp aop) {
147     Block *body = aop.getBody();
148     auto yop = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
149     aop->getBlock()->getOperations().splice(aop->getIterator(),
150                                             body->getOperations());
151     aop.replaceAllUsesWith(yop.getOperands());
152     yop.erase();
153     aop.erase();
154   });
155 }
156 
materializeFusedConstraints(Location loc,OpBuilder & builder,SmallVector<BlockArgument> & argumentCstrs,SmallVector<CstrBroadcastableIntent> & broadcastabilityCstrs)157 Value materializeFusedConstraints(
158     Location loc, OpBuilder &builder, SmallVector<BlockArgument> &argumentCstrs,
159     SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
160   // Ensure to materialize shape_of only once.
161   DenseMap<Value, Value> shapeOfMaterializations;
162   auto getShapeOfMaterialization = [&](Value arg) {
163     auto it = shapeOfMaterializations.find(arg);
164     if (it != shapeOfMaterializations.end()) return it->second;
165     auto shapeOf = builder.create<shape::ShapeOfOp>(loc, arg).getResult();
166     shapeOfMaterializations[arg] = shapeOf;
167     return shapeOf;
168   };
169 
170   SmallVector<Value> witnesses;
171   witnesses.reserve(argumentCstrs.size() + broadcastabilityCstrs.size());
172 
173   // Carry over the argument witnesses.
174   for (BlockArgument it : argumentCstrs) witnesses.push_back(it);
175 
176   // Materialize broadcastability constraints.
177   for (const CstrBroadcastableIntent &it : broadcastabilityCstrs) {
178     auto shapes = llvm::to_vector<8>(llvm::map_range(
179         it.operands,
180         [getShapeOfMaterialization](const CstrBroadcastableOperand &operand) {
181           if (operand.kind == CstrBroadcastableOperandKind::kShapeOfValue) {
182             return getShapeOfMaterialization(operand.value);
183           }
184           assert(operand.kind == CstrBroadcastableOperandKind::kValue);
185           Value shape = operand.value;
186           return shape;
187         }));
188     auto cstr = builder.create<shape::CstrBroadcastableOp>(it.loc, shapes);
189     witnesses.push_back(cstr);
190   }
191   if (witnesses.size() == 1) return witnesses.front();
192   return builder.create<shape::AssumingAllOp>(loc, witnesses);
193 }
194 
materializeBlockGlobalConstraintFusion(Location loc,OpBuilder & builder,Block * theBlock,llvm::SmallSetVector<Operation *,16> & toBeErased,SmallVector<BlockArgument> & argumentCstrs,SmallVector<CstrBroadcastableIntent> & broadcastabilityCstrs)195 void materializeBlockGlobalConstraintFusion(
196     Location loc, OpBuilder &builder, Block *theBlock,
197     llvm::SmallSetVector<Operation *, 16> &toBeErased,
198     SmallVector<BlockArgument> &argumentCstrs,
199     SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
200   // Eliminate the old assuming regions and inline their ops into the main
201   // function body.
202   inlineAssumingRegions(theBlock);
203 
204   // Delete ops that are known to have become redundant by inlining of assuming
205   // regions.
206   for (auto *it : toBeErased) it->erase();
207 
208   // Materialize fused constraints at the beginning of the function.
209   builder.setInsertionPointToStart(theBlock);
210   Value fusedCstr = materializeFusedConstraints(loc, builder, argumentCstrs,
211                                                 broadcastabilityCstrs);
212 
213   // Create fused assuming region with empty body.
214   Operation *theBlockTerminator = theBlock->getTerminator();
215   auto fusedAop = builder.create<shape::AssumingOp>(
216       loc, theBlockTerminator->getOperandTypes(), fusedCstr);
217   auto *fusedAopBody = new Block;
218   fusedAop.getDoRegion().getBlocks().push_back(fusedAopBody);
219 
220   // Splice all the original block's operations into the fused assuming region's
221   // body (except for the block terminator).
222   auto &dstBlocks = fusedAopBody->getOperations();
223   dstBlocks.splice(dstBlocks.begin(), theBlock->getOperations(),
224                    builder.getInsertionPoint(),
225                    theBlockTerminator->getIterator());
226 
227   // Yield results from the assuming region and pass them on to the original
228   // block terminator.
229   builder.setInsertionPointToEnd(fusedAopBody);
230   builder.create<shape::AssumingYieldOp>(loc,
231                                          theBlockTerminator->getOperands());
232   theBlockTerminator->setOperands(fusedAop.getResults());
233 }
234 
isRemainingUse(OpOperand & use,Block * theBlock,llvm::SmallSetVector<Operation *,16> & considerDead)235 bool isRemainingUse(OpOperand &use, Block *theBlock,
236                     llvm::SmallSetVector<Operation *, 16> &considerDead) {
237   Operation *op = use.getOwner();
238 
239   // Not a real use if user is considered dead.
240   if (considerDead.count(op)) return false;
241 
242   // Assuming regions in the regarded block are not a real use as they will be
243   // inlined.
244   if (auto aop = llvm::dyn_cast<shape::AssumingOp>(op))
245     return aop->getBlock() == theBlock;
246 
247   // Look through assuming regions' yield ops.
248   if (auto yop = llvm::dyn_cast<shape::AssumingYieldOp>(op)) {
249     auto aop = yop->getParentOfType<shape::AssumingOp>();
250     auto outerResult = aop.getResults()[use.getOperandNumber()];
251     return llvm::all_of(outerResult.getUses(), [&](auto &outerUse) {
252       return isRemainingUse(outerUse, theBlock, considerDead);
253     });
254   }
255 
256   // Otherwise, consider it a real use.
257   return true;
258 }
259 
tryFlagForErase(Block * theBlock,Operation * op,llvm::SmallSetVector<Operation *,16> & toBeErased)260 void tryFlagForErase(Block *theBlock, Operation *op,
261                      llvm::SmallSetVector<Operation *, 16> &toBeErased) {
262   if (llvm::none_of(op->getUses(), [&](auto &use) {
263         return isRemainingUse(use, theBlock, toBeErased);
264       })) {
265     toBeErased.insert(op);
266   }
267 }
268 
isWithinBlock(Operation * op,Block * theBlock)269 bool isWithinBlock(Operation *op, Block *theBlock) {
270   while (op != nullptr && op->getBlock() != theBlock) op = op->getParentOp();
271   return op != nullptr;
272 }
273 
analyzeBroadcastableConstraint(shape::CstrBroadcastableOp cstrBcastable,Block * theBlock,llvm::SmallSetVector<Operation *,16> & toBeErased,SmallVector<CstrBroadcastableOperand> & transitiveBcastableCstrOperands)274 LogicalResult analyzeBroadcastableConstraint(
275     shape::CstrBroadcastableOp cstrBcastable, Block *theBlock,
276     llvm::SmallSetVector<Operation *, 16> &toBeErased,
277     SmallVector<CstrBroadcastableOperand> &transitiveBcastableCstrOperands) {
278   SmallVector<Value> worklist = cstrBcastable.getShapes();
279   while (!worklist.empty()) {
280     Value shape = worklist.pop_back_val();
281     Operation *def = shape.getDefiningOp();
282 
283     // For shapes without a definition, expect them to be an argument of the
284     // regarded block.
285     if (def == nullptr) {
286       auto barg = shape.dyn_cast<BlockArgument>();
287       if (!barg || barg.getParentBlock() != theBlock) return failure();
288       transitiveBcastableCstrOperands.push_back(
289           CstrBroadcastableOperand::valueOf(barg));
290       continue;
291     }
292 
293     // For shape_of ops, expect them to wrap an argument of the regarded block.
294     // The shape reification pass helps achieve this, which should be run before
295     // this pass.
296     if (auto sof = llvm::dyn_cast<shape::ShapeOfOp>(def)) {
297       if (!isWithinBlock(sof, theBlock)) return failure();
298       tryFlagForErase(theBlock, def, toBeErased);
299       auto barg = sof.getArg().dyn_cast<BlockArgument>();
300       if (!barg) return failure();
301       transitiveBcastableCstrOperands.push_back(
302           CstrBroadcastableOperand::shapeOf(barg));
303       continue;
304     }
305 
306     // For broadcast ops, broadcastability of the operands is an implicit
307     // requirement. We can online the operands.
308     if (auto bcast = llvm::dyn_cast<shape::BroadcastOp>(def)) {
309       if (!isWithinBlock(bcast, theBlock)) return failure();
310       tryFlagForErase(theBlock, def, toBeErased);
311       auto bcastShapes = bcast.getShapes();
312       worklist.append(bcastShapes.begin(), bcastShapes.end());
313       continue;
314     }
315 
316     // Look into assuming ops to proceed.
317     if (auto aop = llvm::dyn_cast<shape::AssumingOp>(def)) {
318       if (!isWithinBlock(aop, theBlock)) return failure();
319       auto yieldOp =
320           llvm::cast<shape::AssumingYieldOp>(aop.getBody()->getTerminator());
321       size_t i = llvm::find(aop.getResults(), shape).getIndex();
322       Value innerShape = yieldOp.getOperand(i);
323       worklist.push_back(innerShape);
324       continue;
325     }
326 
327     // Otherwise, bail.
328     return failure();
329   }
330 
331   return success();
332 }
333 
analyzeBlockGlobalConstraints(Block * theBlock,llvm::SmallSetVector<Operation *,16> & toBeErased,SmallVector<BlockArgument> & argumentCstrs,SmallVector<CstrBroadcastableIntent> & broadcastabilityCstrs)334 LogicalResult analyzeBlockGlobalConstraints(
335     Block *theBlock, llvm::SmallSetVector<Operation *, 16> &toBeErased,
336     SmallVector<BlockArgument> &argumentCstrs,
337     SmallVector<CstrBroadcastableIntent> &broadcastabilityCstrs) {
338   // Find all the assuming regions and start the search for reachable
339   // constraints from there.
340   SmallVector<Value> cstrWorklist;
341   theBlock->walk(
342       [&](shape::AssumingOp aop) { cstrWorklist.push_back(aop.getWitness()); });
343 
344   while (!cstrWorklist.empty()) {
345     Value cstr = cstrWorklist.pop_back_val();
346     Operation *def = cstr.getDefiningOp();
347 
348     // For witnesses without a definition, expect it to be an argument of the
349     // regarded block.
350     if (def == nullptr) {
351       auto barg = cstr.dyn_cast<BlockArgument>();
352       if (!barg || barg.getParentBlock() != theBlock) return failure();
353       argumentCstrs.push_back(barg);
354       continue;
355     }
356 
357     // For conjunctions, continue with the operands.
358     if (auto aaop = llvm::dyn_cast<shape::AssumingAllOp>(def)) {
359       if (!isWithinBlock(aaop, theBlock)) return failure();
360       tryFlagForErase(theBlock, def, toBeErased);
361       auto aaopCstrs = aaop.getOperands();
362       cstrWorklist.append(aaopCstrs.begin(), aaopCstrs.end());
363       continue;
364     }
365 
366     // For broadcastable constraints, find the transitively included shape
367     // operands.
368     if (auto cstrBcastable = llvm::dyn_cast<shape::CstrBroadcastableOp>(def)) {
369       if (!isWithinBlock(cstrBcastable, theBlock)) return failure();
370       tryFlagForErase(theBlock, def, toBeErased);
371       CstrBroadcastableIntent bcastableIntent(cstrBcastable.getLoc());
372       if (failed(analyzeBroadcastableConstraint(
373               cstrBcastable, theBlock, toBeErased, bcastableIntent.operands))) {
374         return failure();
375       }
376       broadcastabilityCstrs.push_back(bcastableIntent);
377       continue;
378     }
379 
380     // Look into assuming regions when running into them. They will be inlined
381     // later.
382     if (auto aop = llvm::dyn_cast<shape::AssumingOp>(def)) {
383       if (!isWithinBlock(aop, theBlock)) return failure();
384       size_t i = llvm::find(aop.getResults(), cstr).getIndex();
385       auto yieldOp =
386           llvm::cast<shape::AssumingYieldOp>(aop.getBody()->getTerminator());
387       cstrWorklist.push_back(yieldOp.getOperand(i));
388       continue;
389     }
390 
391     // Otherwise, bail.
392     return failure();
393   }
394 
395   return success();
396 }
397 
fuseBlockGlobalConstraints(Location loc,OpBuilder & builder,Block * theBlock)398 LogicalResult fuseBlockGlobalConstraints(Location loc, OpBuilder &builder,
399                                          Block *theBlock) {
400   // Analyze block-global constraints.
401   SmallVector<BlockArgument> argumentCstrs;
402   SmallVector<CstrBroadcastableIntent> broadcastabilityCstrs;
403   llvm::SmallSetVector<Operation *, 16> toBeErased;
404   if (failed(analyzeBlockGlobalConstraints(theBlock, toBeErased, argumentCstrs,
405                                            broadcastabilityCstrs))) {
406     return failure();
407   }
408 
409   // Return early if there is nothing to do.
410   if (argumentCstrs.empty() && broadcastabilityCstrs.empty()) {
411     return success();
412   }
413 
414   // Simplify constraints.
415   eliminateDuplicateBlockArguments(argumentCstrs);
416   canonicalizeBroadcastabilityCstrs(broadcastabilityCstrs);
417 
418   // Materialize constraint fusion.
419   materializeBlockGlobalConstraintFusion(loc, builder, theBlock, toBeErased,
420                                          argumentCstrs, broadcastabilityCstrs);
421 
422   return success();
423 }
424 
425 struct ConstraintFusionPass
426     : public ConstraintFusionPassBase<ConstraintFusionPass> {
getDependentDialectsmlir::mhlo::__anon08b4e7980111::ConstraintFusionPass427   void getDependentDialects(DialectRegistry &registry) const override {
428     registry.insert<shape::ShapeDialect>();
429   }
430 
runOnOperationmlir::mhlo::__anon08b4e7980111::ConstraintFusionPass431   void runOnOperation() override {
432     func::FuncOp f = getOperation();
433     auto loc = f.getLoc();
434     OpBuilder builder(&getContext());
435     for (auto &block : f.getBody().getBlocks()) {
436       if (failed(fuseBlockGlobalConstraints(loc, builder, &block)))
437         return signalPassFailure();
438     }
439   }
440 };
441 
442 }  // namespace
443 
createConstraintFusionPass()444 std::unique_ptr<OperationPass<func::FuncOp>> createConstraintFusionPass() {
445   return std::make_unique<ConstraintFusionPass>();
446 }
447 
448 }  // namespace mhlo
449 }  // namespace mlir
450