#include <torch/csrc/jit/codegen/onednn/guard_shape.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
//! Shamelessly copied from NNC (tensorexpr_fuser) with very little
//! modification; the original code is at:
//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
//!
//! We assume that LLGA does not have operators that depend on the content of
//! the tensor.
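//!
//! For orientation, a rough sketch of what guarding a fusion group does (the
//! exact IR is produced by insertTypeGuard and may differ in detail): a call
//! like
//!
//!   %out : Tensor = prim::oneDNNFusionGroup_0(%a, %b)
//!
//! is rewritten into roughly
//!
//!   %a_chk, %b_chk, %hit = prim::oneDNNFusionGuard[types=[...]](%a, %b)
//!   %out : Tensor = prim::If(%hit)
//!     block0():
//!       %fused = prim::oneDNNFusionGroup_0(%a_chk, %b_chk)
//!       -> (%fused)
//!     block1():
//!       %unfused = prim::FallbackGraph_0(%a, %b)
//!       -> (%unfused)
//!
//! so the fused kernel only runs when the profiled tensor types still hold,
//! and execution otherwise falls back to the unfused subgraph.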
void prepareFusionGroupAndGuardOutputs(Block* block) {
  std::vector<Node*> fusion_groups;
  // Recurse into nested blocks first, then collect the oneDNN fusion groups
  // at this level.
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      prepareFusionGroupAndGuardOutputs(b);
    }
    if (n->kind() == prim::oneDNNFusionGroup) {
      fusion_groups.push_back(n);
    }
  }
  for (Node* fusion_group : fusion_groups) {
    // TODO: add further optimization pass to removeOutputsUsedOnlyInSize,
    // refer to
    // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
    // removeOutputsUsedOnlyInSize(fusion_group);
    // Guard each fusion group with a prim::oneDNNFusionGuard so that the
    // fused kernel only executes when the input tensor types match what was
    // profiled; otherwise execution falls back to the unfused subgraph.
    insertTypeGuard(
        fusion_group,
        [](const TensorTypePtr& t) { return t; },
        prim::oneDNNFusionGuard);
  }
}
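
// A minimal usage sketch (hypothetical call site; within PyTorch this pass is
// driven by the oneDNN graph fuser pipeline rather than called directly):
//
//   std::shared_ptr<Graph> graph = ...; // graph already containing
//                                       // prim::oneDNNFusionGroup nodes
//   prepareFusionGroupAndGuardOutputs(graph->block());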

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch