// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/fuse_linear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
5 
6 namespace torch::jit {
7 
FuseLinear(std::shared_ptr<Graph> & graph)8 void FuseLinear(std::shared_ptr<Graph>& graph) {
9   std::string addmm_pattern = R"IR(
10     graph(%input, %weight_t, %bias, %beta, %alpha):
11         %res = aten::addmm(%bias, %input, %weight_t, %beta, %alpha)
12         return (%res))IR";
13   std::string fused_linear_addmm = R"IR(
14     graph(%input, %weight_t, %bias, %beta, %alpha):
15         %weight = aten::t(%weight_t)
16         %res = aten::linear(%input, %weight, %bias)
17         return (%res))IR";
18 
19   auto beta_is_one = [](const Match& match,
20                         const std::unordered_map<std::string, Value*>& vmap) {
21     return is_int_constant(match, vmap, "beta", 1);
22   };
23 
24   // check %weight_t is produced by `aten::t` to make sure
25   // we can transform the pattern to `aten::linear`
26   auto weight_transposed =
27       [](const Match& match,
28          const std::unordered_map<std::string, Value*>& vmap) {
29         const auto& match_vmap = match.values_map;
30         auto v = match_vmap.at(vmap.at("weight_t"));
31         return v->node()->kind() == Symbol::aten("t");
32       };
33 
34   // replace addmm pattern to linear
35   SubgraphRewriter addmm_to_linear;
36   std::vector<std::pair<std::string, std::string>> value_mappings(
37       {{"weight", "res"}, {"res", "res"}});
38   addmm_to_linear.RegisterRewritePattern(
39       addmm_pattern, fused_linear_addmm, value_mappings);
40   addmm_to_linear.runOnGraph(
41       graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed});
42 
43   std::string matmul_add_pattern = R"IR(
44     graph(%input, %weight_t, %bias, %alpha):
45         %output = aten::matmul(%input, %weight_t)
46         %res = aten::add_(%output, %bias, %alpha)
47         return (%res))IR";
48   std::string fused_linear_matmul = R"IR(
49     graph(%input, %weight_t, %bias, %alpha):
50         %weight = aten::t(%weight_t)
51         %res = aten::linear(%input, %weight, %bias)
52         return (%res))IR";
53   value_mappings = {{"weight", "output"}, {"res", "output"}};
54   // replace matmul + add pattern to linear
55   SubgraphRewriter matmuladd_to_linear;
56   matmuladd_to_linear.RegisterRewritePattern(
57       matmul_add_pattern, fused_linear_matmul, value_mappings);
58   matmuladd_to_linear.runOnGraph(
59       graph, {aten_add_alpha_is_one, weight_transposed});
60 
61   std::string matmul_pattern = R"IR(
62     graph(%input, %weight_t):
63         %output = aten::matmul(%input, %weight_t)
64         return (%output))IR";
65   std::string fused_linear_bias_none = R"IR(
66     graph(%input, %weight_t):
67         %weight = aten::t(%weight_t)
68         %bias: Tensor? = prim::Constant()
69         %res = aten::linear(%input, %weight, %bias)
70         return (%res))IR";
71 
72   // replace matmul with bias=None pattern to linear
73   SubgraphRewriter matmul_to_linear;
74   matmul_to_linear.RegisterRewritePattern(
75       matmul_pattern, fused_linear_bias_none, value_mappings);
76   matmul_to_linear.runOnGraph(graph, weight_transposed);
77 
78   // clean up extra transpose for the weight of aten::linear
79   std::string linear_weight_extra_transpose = R"IR(
80     graph(%input, %weight, %bias):
81         %weight_t1 = aten::t(%weight)
82         %weight_t2 = aten::t(%weight_t1)
83         %res = aten::linear(%input, %weight_t2, %bias)
84         return (%res))IR";
85 
86   std::string linear_weight_no_transpose = R"IR(
87     graph(%input, %weight, %bias):
88         %res = aten::linear(%input, %weight, %bias)
89         return (%res))IR";
90 
91   value_mappings = {{"res", "res"}};
92   SubgraphRewriter cleanup;
93   cleanup.RegisterRewritePattern(
94       linear_weight_extra_transpose,
95       linear_weight_no_transpose,
96       value_mappings);
97   cleanup.runOnGraph(graph);
98 
99   SwapFunctionalLinear(graph);
100 }
101 
SwapFunctionalLinear(Module & module)102 void SwapFunctionalLinear(Module& module) {
103   for (auto& method : module.get_methods()) {
104     std::shared_ptr<Graph> g = method.graph();
105     SwapFunctionalLinear(g);
106   }
107   for (Module m : module.children()) {
108     SwapFunctionalLinear(m);
109   }
110 }
111 
SwapFunctionalLinear(std::shared_ptr<Graph> & graph)112 void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) {
113   std::string functional_linear = R"(
114 graph(%linear, %input, %weight, %bias):
115   %r = prim::CallFunction(%linear, %input, %weight, %bias)
116   return (%r) )";
117   std::string aten_linear = R"(
118 graph(%linear, %input, %weight, %bias):
119   %r = aten::linear(%input, %weight, %bias)
120   return (%r) )";
121 
122   auto filter = [](const Match& match,
123                    const std::unordered_map<std::string, Value*>& vmap) {
124     const auto& match_vmap = match.values_map;
125     auto linear = graph_rewrite_helper::getValue("linear", match_vmap, vmap);
126     auto func_name = graph_rewrite_helper::getFuncName(linear);
127     return func_name == "linear";
128   };
129   SubgraphRewriter rewriter;
130   rewriter.RegisterRewritePattern(functional_linear, aten_linear);
131   rewriter.runOnGraph(graph, filter);
132 }
133 
134 } // namespace torch::jit
135