1 #include <gtest/gtest.h>
2
3 #include <test/cpp/jit/test_utils.h>
4
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/passes/peephole.h>
8
9 namespace torch {
10 namespace jit {
11
// Test the `is` / `is not` None optimization: the left operand is a plain
// (non-optional) `int`, and the test expects PeepholeOptimize to remove both
// the `aten::__is__` and the `aten::__isnot__` comparison against the None
// constant.
TEST(PeepholeOptimizeTest, IsAndIsNot) {
  constexpr auto kIR = R"IR(
graph(%0 : int):
  %1 : None = prim::Constant()
  %2 : bool = aten::__is__(%0, %1)
  %3 : bool = aten::__isnot__(%0, %1)
  return (%2, %3)
)IR";

  auto g = std::make_shared<Graph>();
  parseIR(kIR, g.get());
  PeepholeOptimize(g);

  // Neither comparison node may survive optimization.
  testing::FileCheck checker;
  checker.check_not("aten::__is__")->check_not("aten::__isnot__");
  checker.run(*g);
}
31
// Negative case for the `is` / `is not` None optimization: the operand is an
// `int?`, so the test expects PeepholeOptimize to leave both `aten::__is__`
// and `aten::__isnot__` in the graph.
TEST(PeepholeOptimizeTest, IsAndIsNot2) {
  constexpr auto kIR = R"IR(
graph(%0: int?):
  %1 : None = prim::Constant()
  %2 : bool = aten::__is__(%0, %1)
  %3 : bool = aten::__isnot__(%0, %1)
  return (%2, %3)
)IR";

  auto g = std::make_shared<Graph>();
  parseIR(kIR, g.get());
  PeepholeOptimize(g);

  // Both comparisons must still be present, in this order.
  testing::FileCheck checker;
  checker.check("aten::__is__")->check("aten::__isnot__");
  checker.run(*g);
}
49
// Mixed case: `%1` is produced by prim::AutogradZero. The test expects
// `aten::__is__(%0, %1)` (with `%0 : int?`) to survive PeepholeOptimize,
// while `aten::__isnot__(%1, None)` is expected to be folded away.
TEST(PeepholeOptimizeTest, IsAndIsNot3) {
  constexpr auto kIR = R"IR(
graph(%0: int?):
  %1 : Tensor = prim::AutogradZero()
  %2 : None = prim::Constant()
  %4 : bool = aten::__is__(%0, %1)
  %5 : bool = aten::__isnot__(%1, %2)
  return (%4, %5)
)IR";

  auto g = std::make_shared<Graph>();
  parseIR(kIR, g.get());
  PeepholeOptimize(g);

  // `__is__` stays; no `__isnot__` may appear after it.
  testing::FileCheck checker;
  checker.check("aten::__is__")->check_not("aten::__isnot__");
  checker.run(*g);
}
68
// Test unwrap-optional elimination: the unwrapped value `%1` is a plain
// (non-optional) tensor constant, so the test expects PeepholeOptimize to
// remove both `aten::_unwrap_optional` and `prim::unchecked_unwrap_optional`.
//
// The unwrap outputs are typed `Tensor`, which is what unwrapping a tensor
// value yields, so the test graph is well-typed.
TEST(PeepholeOptimizeTest, UnwrapOptional) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph():
  %1 : Float(*, *, *) = prim::Constant()
  %2 : Tensor = aten::_unwrap_optional(%1)
  %3 : Tensor = prim::unchecked_unwrap_optional(%1)
  return (%2, %3)
)IR",
      graph.get());
  PeepholeOptimize(graph);
  // No node whose printed form contains "unwrap" may remain.
  testing::FileCheck().check_not("unwrap")->run(*graph);
}
85
// Negative case for unwrap-optional elimination: the input is an optional
// tensor (`Float(*, *, *)?`), so the test expects PeepholeOptimize to keep
// both unwrap nodes.
//
// The unwrap outputs are typed `Tensor`, the element type of the optional
// input, so the test graph is well-typed.
TEST(PeepholeOptimizeTest, UnwrapOptional2) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%1 : Float(*, *, *)?):
  %2 : Tensor = aten::_unwrap_optional(%1)
  %3 : Tensor = prim::unchecked_unwrap_optional(%1)
  return (%2, %3)
)IR",
      graph.get());
  PeepholeOptimize(graph);
  // Both unwrap nodes (one `aten::`, one `prim::`) must still be present.
  testing::FileCheck().check_count("unwrap", 2)->run(*graph);
}
99
// Test FuseAddMM: the graph contains an `aten::mm` whose result feeds an
// `aten::add` with alpha = 1. The test only checks that an `addmm` node is
// present after the pass runs.
// NOTE(review): the operands are rank-3 (`Float(2, 3, 4)`), which a real
// `aten::mm` would reject; presumably irrelevant because the pass pattern-
// matches on node kinds — confirm.
TEST(PeepholeOptimizeTest, AddMMFusion) {
  constexpr auto kIR = R"IR(
graph(
    %0 : Float(2, 3, 4),
    %1 : Float(2, 3, 4),
    %2 : Float(1, 1, 1)):
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::mm(%0, %1)
  %5 : Tensor = aten::add(%4, %2, %3)
  %6 : Tensor = aten::add(%5, %2, %3)
  return (%6)
)IR";

  auto g = std::make_shared<Graph>();
  parseIR(kIR, g.get());
  FuseAddMM(g);

  testing::FileCheck checker;
  checker.check("addmm");
  checker.run(*g);
}
118 } // namespace jit
119 } // namespace torch
120