xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_op_replacement.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
5 #include <torch/csrc/jit/operator_upgraders/version_map.h>
6 #include <torch/csrc/jit/passes/replacement_of_old_operators.h>
7 #include <memory>
8 
9 namespace torch {
10 namespace jit {
11 
// Upgrader name -> TorchScript IR text of that upgrader's graph. The tests in
// this file register these via test_only_populate_upgraders() and then expect
// ReplaceOldOperatorsWithUpgraders() to substitute the matching graph for an
// outdated operator call.
//
// "_test_serialization_subcmul_0_2": computes other - self * alpha
//   (aten::mul(self, alpha), then aten::sub(other, <product>, 1)). The key must
//   match the upgrader-name string in the UpgraderEntry that the subcmul tests
//   register for "aten::_test_serialization_subcmul".
// "div_Tensor_0_3": if either input is floating point, calls plain aten::div;
//   otherwise calls aten::div with rounding mode "trunc" (presumably to keep
//   the truncating integer-division behavior of older operator versions).
// NOTE(review): this variable has external linkage; it could be `static` (or
//   live in an anonymous namespace) unless another translation unit uses it.
12 std::unordered_map<std::string, std::string> test_upgraders(
13     {{"_test_serialization_subcmul_0_2", R"IR(graph(%self.1 : Tensor,
14                                                     %other.1 : Tensor,
15                                                     %alpha.1 : Union(float, int)):
16                                                 %7 : int = prim::Constant[value=1]()
17                                                 %6 : Tensor = aten::mul(%self.1, %alpha.1) # torch/jit/operator_upgraders.py:18:20
18                                                 %8 : Tensor = aten::sub(%other.1, %6, %7) # torch/jit/operator_upgraders.py:18:11
19                                                 return (%8))IR"},
20      {"div_Tensor_0_3", R"IR(graph(%self.1 : Tensor,
21                                   %other.1 : Tensor):
22                             %32 : str = prim::Constant[value="trunc"]()
23                             %6 : bool = prim::Constant[value=1]()
24                             %4 : bool = aten::is_floating_point(%self.1)
25                             %11 : bool = prim::If(%4)
26                                 block0():
27                                     -> (%6)
28                                 block1():
29                                     %9 : bool = aten::is_floating_point(%other.1)
30                                     -> (%9)
31                             %35 : Tensor = prim::If(%11)
32                                 block0():
33                                     %36 : Tensor = aten::div(%self.1, %other.1)
34                                     -> (%36)
35                                 block1():
36                                     %37 : Tensor = aten::div(%self.1, %other.1, %32)
37                                     -> (%37)
38                             return (%35))IR"}});
39 
// Checks that an aten::div in a graph tagged with operator version 2 is
// replaced by the body of the "div_Tensor_0_3" upgrader from test_upgraders:
// a prim::If that selects between aten::div(%2, %1) and aten::div(%2, %1, %4),
// where %4 is presumably the inlined "trunc" rounding-mode constant.
TEST(OpReplacementTest, ReplaceDivInSimpleFunction) {
  const auto graph_string = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor):
            %2 : Tensor = aten::add(%0, %1)
            %3 : Tensor  = aten::div(%2, %1)
            return (%3))IR";
  auto g = std::make_shared<Graph>();
  test_only_populate_upgraders(test_upgraders);
  torch::jit::parseIR(graph_string, g.get());
  // Tag the graph as serialized at operator version 2 so the replacement pass
  // treats its operators as outdated.
  g->set_op_version(2);
  ReplaceOldOperatorsWithUpgraders(g);
  testing::FileCheck()
      .check("prim::If")
      ->check_count("aten::div(%2, %1)", 1, /*exactly=*/true)
      ->check_count("aten::div(%2, %1, %4)", 1, /*exactly=*/true)
      ->run(*g);
  // Undo the global registration so upgrader state does not leak into later
  // tests; every other test in this file cleans up the same way.
  test_only_remove_upgraders(test_upgraders);
}
58 
// Checks that two different outdated operators in the same graph are both
// replaced in a single pass:
//   * aten::div via the "div_Tensor_0_3" upgrader (a two-branch prim::If), and
//   * aten::_test_serialization_subcmul via a temporary version-map entry that
//     points at the "_test_serialization_subcmul_0_2" upgrader (aten::mul then
//     aten::sub).
TEST(OpReplacementTest, ReplaceTwoOpsInSimpleFunction) {
  const auto graph_string = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor):
            %2 : Tensor = aten::add(%0, %1)
            %3 : Tensor  = aten::div(%2, %1)
            %4 : int = prim::Constant[value=1]()
            %5: Tensor = aten::_test_serialization_subcmul(%0, %1, %4)
            return (%3, %5))IR";
  auto g = std::make_shared<Graph>();
  test_only_populate_upgraders(test_upgraders);
  // Temporary version-map entry for the test-only op. 3 is presumably the
  // operator version at which its schema was bumped; the graph below is
  // tagged as version 2, so the entry applies.
  UpgraderEntry test_entry{
      3,
      "_test_serialization_subcmul_0_2",
      "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"};
  test_only_add_entry("aten::_test_serialization_subcmul", test_entry);
  torch::jit::parseIR(graph_string, g.get());
  g->set_op_version(2);
  ReplaceOldOperatorsWithUpgraders(g);
  // div replacement: both branches of the inlined prim::If call aten::div.
  testing::FileCheck()
      .check("prim::If")
      ->check_count("aten::div", 2, /*exactly=*/true)
      ->run(*g);
  // subcmul replacement: neither the input graph nor the div upgrader contains
  // aten::mul or aten::sub, so these only match if the subcmul upgrader body
  // was actually inlined. Without them the second replacement goes unchecked.
  testing::FileCheck().check_count("aten::mul", 1, /*exactly=*/false)->run(*g);
  testing::FileCheck().check_count("aten::sub", 1, /*exactly=*/false)->run(*g);
  // Restore global state for subsequent tests.
  test_only_remove_entry("aten::_test_serialization_subcmul");
  test_only_remove_upgraders(test_upgraders);
}
85 
// Checks that an outdated aten::div nested inside a prim::If block, rather
// than at the top level of the graph, is also replaced by the
// "div_Tensor_0_3" upgrader body.
TEST(OpReplacementTest, ReplaceDivInNestedFunction) {
  const auto ir = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor,
              %8 : bool):
            %9 : bool = prim::Constant[value=1]()
            %7 : bool = prim::If(%8)
                block0():
                    -> (%9)
                block1():
                    %2 : Tensor = aten::add(%0, %1)
                    %3 : Tensor  = aten::div(%2, %1)
                    %4 : Tensor = aten::add(%3, %0)
                    %10 : bool = aten::is_floating_point(%4)
                    -> (%10)
            return (%7))IR";
  auto graph = std::make_shared<Graph>();
  test_only_populate_upgraders(test_upgraders);
  torch::jit::parseIR(ir, graph.get());
  graph->set_op_version(2);
  ReplaceOldOperatorsWithUpgraders(graph);

  // After the first prim::If, expect at least two of each op: both original
  // aten::add calls, and the two aten::div calls of the inlined upgrader.
  for (const char* op : {"aten::add", "aten::div"}) {
    testing::FileCheck()
        .check("prim::If")
        ->check_count(op, 2, /*exactly=*/false)
        ->run(*graph);
  }

  test_only_remove_upgraders(test_upgraders);
}
118 
// Checks that the test-only op aten::_test_serialization_subcmul, once it has
// a version-map entry, is replaced by the "_test_serialization_subcmul_0_2"
// upgrader body, which is made of aten::mul and aten::sub.
TEST(OpReplacementTest, ReplaceTestSubcmulInSimpleFunction) {
  const auto ir = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor):
            %3 : int = prim::Constant[value=1]()
            %2 : Tensor = aten::_test_serialization_subcmul(%0, %1, %3)
            return (%2))IR";
  auto graph = std::make_shared<Graph>();
  test_only_populate_upgraders(test_upgraders);

  // Temporary version-map entry routing the op to its upgrader; 3 is
  // presumably the version at which the op's schema was bumped.
  UpgraderEntry subcmul_entry{
      3,
      "_test_serialization_subcmul_0_2",
      "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"};
  test_only_add_entry("aten::_test_serialization_subcmul", subcmul_entry);

  torch::jit::parseIR(ir, graph.get());
  graph->set_op_version(2);
  ReplaceOldOperatorsWithUpgraders(graph);

  // Each op of the upgrader body must appear at least once in the result.
  const auto expect_present = [&graph](const char* op) {
    testing::FileCheck().check_count(op, 1, /*exactly=*/false)->run(*graph);
  };
  expect_present("aten::mul");
  expect_present("aten::sub");

  // Restore global state for subsequent tests.
  test_only_remove_upgraders(test_upgraders);
  test_only_remove_entry("aten::_test_serialization_subcmul");
}
143 
144 } // namespace jit
145 } // namespace torch
146