1 #include <torch/csrc/jit/passes/fuse_linear.h>
2 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
3 #include <torch/csrc/jit/passes/quantization/helper.h>
4 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
5
6 namespace torch::jit {
7
FuseLinear(std::shared_ptr<Graph> & graph)8 void FuseLinear(std::shared_ptr<Graph>& graph) {
9 std::string addmm_pattern = R"IR(
10 graph(%input, %weight_t, %bias, %beta, %alpha):
11 %res = aten::addmm(%bias, %input, %weight_t, %beta, %alpha)
12 return (%res))IR";
13 std::string fused_linear_addmm = R"IR(
14 graph(%input, %weight_t, %bias, %beta, %alpha):
15 %weight = aten::t(%weight_t)
16 %res = aten::linear(%input, %weight, %bias)
17 return (%res))IR";
18
19 auto beta_is_one = [](const Match& match,
20 const std::unordered_map<std::string, Value*>& vmap) {
21 return is_int_constant(match, vmap, "beta", 1);
22 };
23
24 // check %weight_t is produced by `aten::t` to make sure
25 // we can transform the pattern to `aten::linear`
26 auto weight_transposed =
27 [](const Match& match,
28 const std::unordered_map<std::string, Value*>& vmap) {
29 const auto& match_vmap = match.values_map;
30 auto v = match_vmap.at(vmap.at("weight_t"));
31 return v->node()->kind() == Symbol::aten("t");
32 };
33
34 // replace addmm pattern to linear
35 SubgraphRewriter addmm_to_linear;
36 std::vector<std::pair<std::string, std::string>> value_mappings(
37 {{"weight", "res"}, {"res", "res"}});
38 addmm_to_linear.RegisterRewritePattern(
39 addmm_pattern, fused_linear_addmm, value_mappings);
40 addmm_to_linear.runOnGraph(
41 graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed});
42
43 std::string matmul_add_pattern = R"IR(
44 graph(%input, %weight_t, %bias, %alpha):
45 %output = aten::matmul(%input, %weight_t)
46 %res = aten::add_(%output, %bias, %alpha)
47 return (%res))IR";
48 std::string fused_linear_matmul = R"IR(
49 graph(%input, %weight_t, %bias, %alpha):
50 %weight = aten::t(%weight_t)
51 %res = aten::linear(%input, %weight, %bias)
52 return (%res))IR";
53 value_mappings = {{"weight", "output"}, {"res", "output"}};
54 // replace matmul + add pattern to linear
55 SubgraphRewriter matmuladd_to_linear;
56 matmuladd_to_linear.RegisterRewritePattern(
57 matmul_add_pattern, fused_linear_matmul, value_mappings);
58 matmuladd_to_linear.runOnGraph(
59 graph, {aten_add_alpha_is_one, weight_transposed});
60
61 std::string matmul_pattern = R"IR(
62 graph(%input, %weight_t):
63 %output = aten::matmul(%input, %weight_t)
64 return (%output))IR";
65 std::string fused_linear_bias_none = R"IR(
66 graph(%input, %weight_t):
67 %weight = aten::t(%weight_t)
68 %bias: Tensor? = prim::Constant()
69 %res = aten::linear(%input, %weight, %bias)
70 return (%res))IR";
71
72 // replace matmul with bias=None pattern to linear
73 SubgraphRewriter matmul_to_linear;
74 matmul_to_linear.RegisterRewritePattern(
75 matmul_pattern, fused_linear_bias_none, value_mappings);
76 matmul_to_linear.runOnGraph(graph, weight_transposed);
77
78 // clean up extra transpose for the weight of aten::linear
79 std::string linear_weight_extra_transpose = R"IR(
80 graph(%input, %weight, %bias):
81 %weight_t1 = aten::t(%weight)
82 %weight_t2 = aten::t(%weight_t1)
83 %res = aten::linear(%input, %weight_t2, %bias)
84 return (%res))IR";
85
86 std::string linear_weight_no_transpose = R"IR(
87 graph(%input, %weight, %bias):
88 %res = aten::linear(%input, %weight, %bias)
89 return (%res))IR";
90
91 value_mappings = {{"res", "res"}};
92 SubgraphRewriter cleanup;
93 cleanup.RegisterRewritePattern(
94 linear_weight_extra_transpose,
95 linear_weight_no_transpose,
96 value_mappings);
97 cleanup.runOnGraph(graph);
98
99 SwapFunctionalLinear(graph);
100 }
101
SwapFunctionalLinear(Module & module)102 void SwapFunctionalLinear(Module& module) {
103 for (auto& method : module.get_methods()) {
104 std::shared_ptr<Graph> g = method.graph();
105 SwapFunctionalLinear(g);
106 }
107 for (Module m : module.children()) {
108 SwapFunctionalLinear(m);
109 }
110 }
111
SwapFunctionalLinear(std::shared_ptr<Graph> & graph)112 void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) {
113 std::string functional_linear = R"(
114 graph(%linear, %input, %weight, %bias):
115 %r = prim::CallFunction(%linear, %input, %weight, %bias)
116 return (%r) )";
117 std::string aten_linear = R"(
118 graph(%linear, %input, %weight, %bias):
119 %r = aten::linear(%input, %weight, %bias)
120 return (%r) )";
121
122 auto filter = [](const Match& match,
123 const std::unordered_map<std::string, Value*>& vmap) {
124 const auto& match_vmap = match.values_map;
125 auto linear = graph_rewrite_helper::getValue("linear", match_vmap, vmap);
126 auto func_name = graph_rewrite_helper::getFuncName(linear);
127 return func_name == "linear";
128 };
129 SubgraphRewriter rewriter;
130 rewriter.RegisterRewritePattern(functional_linear, aten_linear);
131 rewriter.runOnGraph(graph, filter);
132 }
133
134 } // namespace torch::jit
135