1 #include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
4
5 namespace torch {
6 namespace jit {
7 namespace fuser {
8 namespace onednn {
9
10 class SizeCheckMover {
11 private:
12 Block* block_;
13 std::shared_ptr<Graph> graph_;
14
15 public:
SizeCheckMover(Block * block,std::shared_ptr<Graph> graph)16 SizeCheckMover(Block* block, std::shared_ptr<Graph> graph)
17 : block_(block), graph_(std::move(graph)) {}
18
analyzeNode(Node * node,AliasDb & aliasDb)19 bool analyzeNode(Node* node, AliasDb& aliasDb) {
20 //
21 // %b = addmm(%a)
22 // %sz = aten::size(%b)
23 // %c = relu(%b)
24 // =>
25 // %b = addmm(%a)
26 // %c = relu(%b)
27 // %sz = aten::size(%c)
28 // ^-- move size check after relu as it preserves input shape
29 //
30 if (!node->matches("aten::size(Tensor self) -> int[]"))
31 return false;
32
33 auto* input = node->input(0);
34 auto& uses = input->uses();
35 bool onlyUsedByShapePreserveOp =
36 uses.size() > 1 && std::all_of(uses.begin(), uses.end(), [&](auto& u) {
37 if (u.user == node) {
38 return true;
39 }
40 // match with shape-preserving unary ops in
41 // tensorexpr_elementwise_set that's defined in
42 // torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
43 OperatorMap<std::string> schemaMap = get_tensorexpr_elementwise_set();
44 std::optional<std::string> mapping =
45 schemaMap.find(u.user->getOperator());
46 return mapping == "unary";
47 });
48
49 if (!onlyUsedByShapePreserveOp)
50 return false;
51
52 for (const auto& use : uses) {
53 if (use.user == node)
54 continue;
55 auto shapePreserveOp = use.user;
56 if (aliasDb.moveAfterTopologicallyValid(node, shapePreserveOp)) {
57 node->replaceInputWith(input, shapePreserveOp->output(0));
58 return true;
59 }
60 }
61
62 return false;
63 }
64
run()65 void run() {
66 bool changed = true;
67 while (changed) {
68 changed = false;
69 AliasDb aliasDb(graph_);
70 for (Node* node : block_->nodes()) {
71 changed |= analyzeNode(node, aliasDb);
72 }
73 }
74
75 for (Node* node : block_->nodes())
76 for (Block* subBlock : node->blocks())
77 SizeCheckMover(subBlock, graph_).run();
78 }
79 };
80
DeferSizeCheck(std::shared_ptr<Graph> & graph)81 void DeferSizeCheck(std::shared_ptr<Graph>& graph) {
82 SizeCheckMover(graph->block(), graph).run();
83 }
84
85 } // namespace onednn
86 } // namespace fuser
87 } // namespace jit
88 } // namespace torch
89