xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_cleanup_passes.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/testing/file_check.h>
7 
8 namespace torch {
9 namespace jit {
10 
TEST(CleanupPassTest,Basic)11 TEST(CleanupPassTest, Basic) {
12   // Tests stability of clean up passes when dealing with constant pooling
13   // and constant propagation.
14   auto graph = std::make_shared<Graph>();
15   parseIR(
16       R"IR(
17 graph(%cond.1 : Tensor,
18       %suffix.1 : str):
19   %3 : bool = aten::Bool(%cond.1) # o.py:6:7
20   %25 : str = prim::If(%3) # o.py:6:4
21     block0():
22       %a.1 : str = prim::Constant[value="same string"]()
23       %b.1 : str = prim::Constant[value=" with a twist"]()
24       %7 : str = aten::add(%a.1, %b.1)
25       %11 : str = aten::add(%7, %suffix.1) # o.py:10:15
26       -> (%11)
27     block1():
28       %c.1 : str = prim::Constant[value="same string"]()
29       %d.1 : str = prim::Constant[value=" with a twist"]()
30       %12 : str = aten::add(%c.1, %d.1)
31       -> (%12)
32   return (%25)
33   )IR",
34       &*graph);
35   runCleanupPasses(graph);
36   testing::FileCheck()
37       .check_count(
38           "prim::Constant[value=\"same string with a twist\"]",
39           1,
40           /*exactly=*/true)
41       ->run(*graph);
42 
43   auto graph_after_pass_once = graph->toString();
44   runCleanupPasses(graph);
45   auto graph_after_pass_twice = graph->toString();
46   ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
47 }
48 } // namespace jit
49 } // namespace torch
50