#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/operator.h>

#include <ATen/code_template.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Only decompose aten::silu when its input comes from an operator that the
// oneDNN graph (LLGA) bridge can fuse with the resulting sigmoid + mul,
// i.e. a non-transposed convolution or a linear layer.
static bool shouldDecomposeSilu(Node* node) {
  if (node->kind() != aten::silu) {
    return false;
  }
  auto inputToSilu = node->input(0)->node();
  if (inputToSilu->kind() == aten::_convolution) {
    // TODO: remove the transpose check once the bridge supports ConvTranspose
    bool transposed = Operator::Bool(inputToSilu, 6);
    return !transposed;
  }
  if (inputToSilu->kind() == aten::linear) {
    return true;
  }
  return false;
}

// Rewrite silu(x) into mul(sigmoid(x), x), propagating the input's tensor
// type onto both inserted nodes.
static void DecomposeSilu(Node* node) {
  if (shouldDecomposeSilu(node)) {
    auto dtype = node->input(0)->type()->expect<TensorType>();

    WithInsertPoint guard(node);
    auto g = node->owningGraph();
    auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
    sigmoid->setType(dtype);

    auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
    mul->setType(dtype);

    node->output()->replaceAllUsesWith(mul);
  }
}

// Recursively walk the block (including nested sub-blocks) and decompose
// every eligible aten::silu node.
static void DecomposeSilu(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeSilu(sub);
    }

    if (node->kind() == aten::silu) {
      DecomposeSilu(node);
    }
  }
}

void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
  DecomposeSilu(graph->block());
  // The replaced silu nodes are now dead; remove them.
  EliminateDeadCode(graph);
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
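
// Usage sketch (illustrative; the call site below is an assumption, not part
// of this file): the pass is intended to run on a JIT graph before LLGA
// fusion, e.g.
//
//   std::shared_ptr<Graph> graph = ...; // contains conv/linear followed by silu
//   torch::jit::fuser::onednn::DecomposeSiluForLLGA(graph);
//   // Eligible aten::silu nodes are now aten::sigmoid + aten::mul, a pattern
//   // the oneDNN graph bridge can map and fuse with the producing op.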