1 #include <torch/csrc/jit/codegen/onednn/guard_shape.h> 2 3 #include <torch/csrc/jit/jit_log.h> 4 #include <torch/csrc/jit/passes/tensorexpr_fuser.h> 5 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 6 #include <torch/csrc/jit/runtime/graph_executor.h> 7 8 namespace torch { 9 namespace jit { 10 namespace fuser { 11 namespace onednn { 12 13 //! [ Note -- prepareFusionGroupAndGuardOutputs implementation ] 14 //! shamelessly copying code from NNC (tensorexpr_fuser) with very little 15 //! modification, original code at: 16 //! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs` 17 //! 18 //! We have the assumption that LLGA does not have operators 19 //! depending on the content of the tensor. prepareFusionGroupAndGuardOutputs(Block * block)20void prepareFusionGroupAndGuardOutputs(Block* block) { 21 std::vector<Node*> fusion_groups; 22 for (Node* n : block->nodes()) { 23 for (Block* b : n->blocks()) { 24 prepareFusionGroupAndGuardOutputs(b); 25 } 26 if (n->kind() == prim::oneDNNFusionGroup) { 27 fusion_groups.push_back(n); 28 } 29 } 30 for (Node* fusion_group : fusion_groups) { 31 // TODO: add further optimization pass to removeOutputsUsedOnlyInSize, 32 // refer to 33 // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize` 34 // removeOutputsUsedOnlyInSize(fusion_group); 35 insertTypeGuard( 36 fusion_group, 37 [](const TensorTypePtr& t) { return t; }, 38 prim::oneDNNFusionGuard); 39 } 40 } 41 42 } // namespace onednn 43 } // namespace fuser 44 } // namespace jit 45 } // namespace torch 46