xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/decompose_silu.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
2 #include <torch/csrc/jit/codegen/onednn/operator.h>
3 
4 #include <ATen/code_template.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
7 
8 namespace torch {
9 namespace jit {
10 namespace fuser {
11 namespace onednn {
12 
shouldDecomposeSilu(Node * node)13 static bool shouldDecomposeSilu(Node* node) {
14   if (node->kind() != aten::silu) {
15     return false;
16   }
17   auto inputToSilu = node->input(0)->node();
18   if (inputToSilu->kind() == aten::_convolution) {
19     // TODO: remove transpose check once the bridge supported ConvTranspose
20     bool transposed = Operator::Bool(inputToSilu, 6);
21     return !transposed;
22   }
23   if (inputToSilu->kind() == aten::linear) {
24     return true;
25   }
26   return false;
27 }
28 
DecomposeSilu(Node * node)29 static void DecomposeSilu(Node* node) {
30   if (shouldDecomposeSilu(node)) {
31     auto dtype = node->input(0)->type()->expect<TensorType>();
32 
33     WithInsertPoint guard(node);
34     auto g = node->owningGraph();
35     auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
36     sigmoid->setType(dtype);
37 
38     auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
39     mul->setType(dtype);
40 
41     node->output()->replaceAllUsesWith(mul);
42   }
43 }
44 
DecomposeSilu(Block * block)45 static void DecomposeSilu(Block* block) {
46   for (auto node : block->nodes()) {
47     for (auto sub : node->blocks()) {
48       DecomposeSilu(sub);
49     }
50 
51     if (node->kind() == aten::silu) {
52       DecomposeSilu(node);
53     }
54   }
55 }
56 
DecomposeSiluForLLGA(std::shared_ptr<Graph> & graph)57 void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
58   DecomposeSilu(graph->block());
59   EliminateDeadCode(graph);
60 }
61 
62 } // namespace onednn
63 } // namespace fuser
64 } // namespace jit
65 } // namespace torch
66