// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_dce.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <memory>
#include <string>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/testing/file_check.h>

7 namespace torch {
8 namespace jit {
TEST(EliminateDeadCodeTest,Basic)9 TEST(EliminateDeadCodeTest, Basic) {
10   auto graph = std::make_shared<Graph>();
11 
12   // Consider the following loop:
13   //   for i in range(3):
14   //     tot += a[0][0]
15   //     b = a[0]
16   //     b[0] += 1
17   //   print(tot)
18   // We want to check that b[0] and b are properly marked as live and thus not
19   // DCE'd.
20   const std::string input =
21       R"IR(
22 graph():
23   %48 : None = prim::Constant()
24   %50 : bool = prim::Constant[value=1]()
25   %0 : int = prim::Constant[value=2]()
26   %12 : int = prim::Constant[value=1]()
27   %24 : int = prim::Constant[value=3]()
28   %31 : int = prim::Constant[value=0]()
29   %2 : int[] = prim::ListConstruct(%0, %0)
30   %a.1 : Tensor = prim::MakeTestTensor()
31   %14 : int[] = prim::ListConstruct(%12)
32   %tot.1 : Tensor = prim::MakeTestTensor()
33   %tot : Tensor = prim::Loop(%24, %50, %tot.1)
34     block0(%i : int, %tot.6 : Tensor):
35       %33 : Tensor = aten::select(%a.1, %31, %31)
36       %35 : Tensor = aten::select(%33, %31, %31)
37       # CHECK: add_
38       %tot.3 : Tensor = aten::add_(%tot.6, %35, %12)
39       %b.1 : Tensor = aten::select(%a.1, %31, %31)
40       %44 : Tensor = aten::select(%b.1, %31, %31)
41       # CHECK: add_
42       %46 : Tensor = aten::add_(%44, %12, %12)
43       -> (%50, %tot.3)
44   return (%tot)
45 )IR";
46   parseIR(input, graph.get());
47   EliminateDeadCode(graph);
48   // Check that dead code elimin
49   testing::FileCheck().run(input, *graph);
50 }
51 } // namespace jit
52 } // namespace torch
53