xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/defer_size_check.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace fuser {
8 namespace onednn {
9 
10 class SizeCheckMover {
11  private:
12   Block* block_;
13   std::shared_ptr<Graph> graph_;
14 
15  public:
SizeCheckMover(Block * block,std::shared_ptr<Graph> graph)16   SizeCheckMover(Block* block, std::shared_ptr<Graph> graph)
17       : block_(block), graph_(std::move(graph)) {}
18 
analyzeNode(Node * node,AliasDb & aliasDb)19   bool analyzeNode(Node* node, AliasDb& aliasDb) {
20     //
21     // %b = addmm(%a)
22     // %sz = aten::size(%b)
23     // %c = relu(%b)
24     //  =>
25     // %b = addmm(%a)
26     // %c = relu(%b)
27     // %sz = aten::size(%c)
28     //       ^-- move size check after relu as it preserves input shape
29     //
30     if (!node->matches("aten::size(Tensor self) -> int[]"))
31       return false;
32 
33     auto* input = node->input(0);
34     auto& uses = input->uses();
35     bool onlyUsedByShapePreserveOp =
36         uses.size() > 1 && std::all_of(uses.begin(), uses.end(), [&](auto& u) {
37           if (u.user == node) {
38             return true;
39           }
40           // match with shape-preserving unary ops in
41           // tensorexpr_elementwise_set that's defined in
42           // torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
43           OperatorMap<std::string> schemaMap = get_tensorexpr_elementwise_set();
44           std::optional<std::string> mapping =
45               schemaMap.find(u.user->getOperator());
46           return mapping == "unary";
47         });
48 
49     if (!onlyUsedByShapePreserveOp)
50       return false;
51 
52     for (const auto& use : uses) {
53       if (use.user == node)
54         continue;
55       auto shapePreserveOp = use.user;
56       if (aliasDb.moveAfterTopologicallyValid(node, shapePreserveOp)) {
57         node->replaceInputWith(input, shapePreserveOp->output(0));
58         return true;
59       }
60     }
61 
62     return false;
63   }
64 
run()65   void run() {
66     bool changed = true;
67     while (changed) {
68       changed = false;
69       AliasDb aliasDb(graph_);
70       for (Node* node : block_->nodes()) {
71         changed |= analyzeNode(node, aliasDb);
72       }
73     }
74 
75     for (Node* node : block_->nodes())
76       for (Block* subBlock : node->blocks())
77         SizeCheckMover(subBlock, graph_).run();
78   }
79 };
80 
DeferSizeCheck(std::shared_ptr<Graph> & graph)81 void DeferSizeCheck(std::shared_ptr<Graph>& graph) {
82   SizeCheckMover(graph->block(), graph).run();
83 }
84 
85 } // namespace onednn
86 } // namespace fuser
87 } // namespace jit
88 } // namespace torch
89