// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_peephole_optimize.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/passes/peephole.h>
8 
9 namespace torch {
10 namespace jit {
11 
TEST(PeepholeOptimizeTest,IsAndIsNot)12 TEST(PeepholeOptimizeTest, IsAndIsNot)
13 // test is / is not none optimization
14 {
15   auto graph = std::make_shared<Graph>();
16   parseIR(
17       R"IR(
18 graph(%0 : int):
19   %1 : None = prim::Constant()
20   %2 : bool = aten::__is__(%0, %1)
21   %3 : bool = aten::__isnot__(%0, %1)
22   return (%2, %3)
23   )IR",
24       graph.get());
25   PeepholeOptimize(graph);
26   testing::FileCheck()
27       .check_not("aten::__is__")
28       ->check_not("aten::__isnot__")
29       ->run(*graph);
30 }
31 
TEST(PeepholeOptimizeTest,IsAndIsNot2)32 TEST(PeepholeOptimizeTest, IsAndIsNot2) {
33   auto graph = std::make_shared<Graph>();
34   parseIR(
35       R"IR(
36 graph(%0: int?):
37   %1 : None = prim::Constant()
38   %2 : bool = aten::__is__(%0, %1)
39   %3 : bool = aten::__isnot__(%0, %1)
40   return (%2, %3)
41   )IR",
42       graph.get());
43   PeepholeOptimize(graph);
44   testing::FileCheck()
45       .check("aten::__is__")
46       ->check("aten::__isnot__")
47       ->run(*graph);
48 }
49 
TEST(PeepholeOptimizeTest,IsAndIsNot3)50 TEST(PeepholeOptimizeTest, IsAndIsNot3) {
51   auto graph = std::make_shared<Graph>();
52   parseIR(
53       R"IR(
54 graph(%0: int?):
55   %1 : Tensor = prim::AutogradZero()
56   %2 : None = prim::Constant()
57   %4 : bool = aten::__is__(%0, %1)
58   %5 : bool = aten::__isnot__(%1, %2)
59   return (%4, %5)
60   )IR",
61       graph.get());
62   PeepholeOptimize(graph);
63   testing::FileCheck()
64       .check("aten::__is__")
65       ->check_not("aten::__isnot__")
66       ->run(*graph);
67 }
68 
TEST(PeepholeOptimizeTest,UnwrapOptional)69 TEST(PeepholeOptimizeTest, UnwrapOptional)
70 // test unwrap optional
71 {
72   auto graph = std::make_shared<Graph>();
73   parseIR(
74       R"IR(
75 graph():
76   %1 : Float(*, *, *) = prim::Constant()
77   %2 : bool = aten::_unwrap_optional(%1)
78   %3 : bool = prim::unchecked_unwrap_optional(%1)
79   return (%2, %3)
80   )IR",
81       graph.get());
82   PeepholeOptimize(graph);
83   testing::FileCheck().check_not("unwrap")->run(*graph);
84 }
85 
TEST(PeepholeOptimizeTest,UnwrapOptional2)86 TEST(PeepholeOptimizeTest, UnwrapOptional2) {
87   auto graph = std::make_shared<Graph>();
88   parseIR(
89       R"IR(
90 graph(%1 : Float(*, *, *)?):
91   %2 : bool = aten::_unwrap_optional(%1)
92   %3 : bool = prim::unchecked_unwrap_optional(%1)
93   return (%2, %3)
94   )IR",
95       graph.get());
96   PeepholeOptimize(graph);
97   testing::FileCheck().check_count("unwrap", 2)->run(*graph);
98 }
99 
TEST(PeepholeOptimizeTest,AddMMFusion)100 TEST(PeepholeOptimizeTest, AddMMFusion) {
101   auto graph = std::make_shared<Graph>();
102   parseIR(
103       R"IR(
104       graph(
105         %0 : Float(2, 3, 4),
106         %1 : Float(2, 3, 4),
107         %2 : Float(1, 1, 1)):
108         %3 : int = prim::Constant[value=1]()
109         %4 : Tensor = aten::mm(%0, %1)
110         %5 : Tensor = aten::add(%4, %2, %3)
111         %6 : Tensor = aten::add(%5, %2, %3)
112         return (%6)
113         )IR",
114       graph.get());
115   FuseAddMM(graph);
116   testing::FileCheck().check("addmm")->run(*graph);
117 }
118 } // namespace jit
119 } // namespace torch
120