xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_constant_pooling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/ir/irparser.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/constant_propagation.h>
7 #include <torch/csrc/jit/testing/file_check.h>
8 
9 #include <sstream>
10 #include <string>
11 
12 namespace torch {
13 namespace jit {
14 
TEST(ConstantPoolingTest,Int)15 TEST(ConstantPoolingTest, Int) {
16   auto graph = std::make_shared<Graph>();
17   parseIR(
18       R"IR(
19 graph():
20   %8 : int = prim::Constant[value=1]()
21   %10 : int = prim::Constant[value=1]()
22   return (%8, %10)
23   )IR",
24       &*graph);
25   ConstantPooling(graph);
26   testing::FileCheck()
27       .check_count("prim::Constant", 1, /*exactly*/ true)
28       ->run(*graph);
29 }
30 
TEST(ConstantPoolingTest,PoolingAcrossBlocks)31 TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
32   auto graph = std::make_shared<Graph>();
33   parseIR(
34       R"IR(
35 graph(%cond : Tensor):
36   %a : str = prim::Constant[value="bcd"]()
37   %3 : bool = aten::Bool(%cond)
38   %b : str = prim::If(%3)
39     block0():
40       %b.1 : str = prim::Constant[value="abc"]()
41       -> (%b.1)
42     block1():
43       %b.2 : str = prim::Constant[value="abc"]()
44       -> (%b.2)
45   %7 : (str, str) = prim::TupleConstruct(%a, %b)
46   return (%7)
47   )IR",
48       &*graph);
49   ConstantPooling(graph);
50   testing::FileCheck()
51       .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
52       ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
53       ->run(*graph);
54 }
55 
TEST(ConstantPoolingTest,PoolingDifferentDevices)56 TEST(ConstantPoolingTest, PoolingDifferentDevices) {
57   auto graph = std::make_shared<Graph>();
58   parseIR(
59       R"IR(
60 graph():
61   %2 : int = prim::Constant[value=2]()
62   %1 : int = prim::Constant[value=1]()
63   %5 : int? = prim::Constant()
64   %7 : Device? = prim::Constant()
65   %15: bool = prim::Constant[value=0]()
66   %10 : int = prim::Constant[value=6]()
67   %3 : int[] = prim::ListConstruct(%1, %2)
68   %x : Tensor = aten::tensor(%3, %5, %7, %15)
69   %y : Tensor = aten::tensor(%3, %10, %7, %15)
70   %9 : int[] = prim::ListConstruct(%1, %2)
71   %z : Tensor = aten::tensor(%9, %10, %7, %15)
72   prim::Print(%x, %y, %z)
73   return (%1)
74   )IR",
75       &*graph);
76   // three tensors created - two different devices among the three
77   // don't have good support for parsing tensor constants
78   ConstantPropagation(graph);
79   ConstantPooling(graph);
80   testing::FileCheck()
81       .check_count(
82           "Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant",
83           1,
84           /*exactly*/ true)
85       ->check_count(
86           "Long(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant",
87           1,
88           /*exactly*/ true)
89       ->run(*graph);
90 }
91 
TEST(ConstantPoolingTest,DictConstantPooling)92 TEST(ConstantPoolingTest, DictConstantPooling) {
93   auto graph = std::make_shared<Graph>();
94   parseIR(
95       R"IR(
96 graph():
97   %0 : int = prim::Constant[value=1]() # test/elias.py:6:9
98   %1 : int = prim::Constant[value=2]() # test/elias.py:6:12
99   %a.1 : Dict(int, int) = prim::DictConstruct(%0, %1)
100   %b.1 : Dict(int, int) = prim::DictConstruct(%1, %1)
101   return (%a.1, %b.1)
102   )IR",
103       &*graph);
104   ConstantPropagation(graph);
105   ConstantPooling(graph);
106   testing::FileCheck()
107       .check_count(
108           "Dict(int, int) = prim::Constant",
109           2,
110           /*exactly*/ true)
111       ->run(*graph);
112 }
113 } // namespace jit
114 } // namespace torch
115