// Source: /aosp_15_r20/external/pytorch/test/cpp/api/parameterdict.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include <test/cpp/api/support.h>

9*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
10*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker struct ParameterDictTest : torch::test::SeedingFixture {};
13*da0073e9SAndroid Build Coastguard Worker 
// Fills a ParameterDict one insert() at a time and checks that:
//  - the dict ends up with one entry per insert,
//  - every stored value equals the tensor that was inserted under that key,
//  - the requires_grad flag of "A" (true) and "B" (false) is carried over.
TEST_F(ParameterDictTest, ConstructFromTensor) {
  ParameterDict dict;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  dict->insert("A", ta);
  dict->insert("B", tb);
  dict->insert("C", tc);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
  ASSERT_TRUE(dict["A"].requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
  ASSERT_FALSE(dict["B"].requires_grad());
  // "C" is inserted above but was previously never read back; verify it so a
  // regression that drops or corrupts the last insertion is caught.
  ASSERT_TRUE(torch::all(torch::eq(dict["C"], tc)).item<bool>());
}

// Builds a ParameterDict from a torch::OrderedDict of tensors and checks that
// every entry is present with the original value, and that the requires_grad
// flag of "A" (true) and "B" (false) is carried over.
TEST_F(ParameterDictTest, ConstructFromOrderedDict) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::OrderedDict<std::string, torch::Tensor> params = {
      {"A", ta}, {"B", tb}, {"C", tc}};
  auto dict = torch::nn::ParameterDict(params);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
  ASSERT_TRUE(dict["A"].requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
  ASSERT_FALSE(dict["B"].requires_grad());
  // The third constructor entry was previously never verified.
  ASSERT_TRUE(torch::all(torch::eq(dict["C"], tc)).item<bool>());
}

// A single insert() yields a one-entry dict. contains() reports the inserted
// key and rejects a key that was never added.
TEST_F(ParameterDictTest, InsertAndContains) {
  ParameterDict pd;
  pd->insert("A", torch::tensor({1.0}));

  ASSERT_EQ(pd->size(), 1);
  ASSERT_TRUE(pd->contains("A"));
  ASSERT_FALSE(pd->contains("C"));
}

// clear() removes every entry: the size drops from one back to zero.
TEST_F(ParameterDictTest, InsertAndClear) {
  ParameterDict pd;
  pd->insert("A", torch::tensor({1.0}));
  ASSERT_EQ(pd->size(), 1);

  pd->clear();

  ASSERT_EQ(pd->size(), 0);
}

// pop() of a missing key is expected to throw with the message
// "Parameter 'B' is not defined". pop() of a present key removes it and
// returns the stored tensor.
TEST_F(ParameterDictTest, InsertAndPop) {
  ParameterDict pd;
  pd->insert("A", torch::tensor({1.0}));
  ASSERT_EQ(pd->size(), 1);

  ASSERT_THROWS_WITH(pd->pop("B"), "Parameter 'B' is not defined");

  const torch::Tensor popped = pd->pop("A");
  ASSERT_EQ(pd->size(), 0);

  const torch::Tensor expected = torch::tensor({1.0});
  ASSERT_TRUE(torch::eq(popped, expected).item<bool>());
}

// update() overwrites the values of keys that already exist in the target.
// A source dict containing a key the target lacks ("D") makes update() throw
// "Parameter 'D' is not defined", and the target must not gain that key.
TEST_F(ParameterDictTest, SimpleUpdate) {
  ParameterDict dict;
  ParameterDict wrongDict;
  ParameterDict rightDict;
  dict->insert("A", torch::tensor({1.0}));
  dict->insert("B", torch::tensor({2.0}));
  dict->insert("C", torch::tensor({3.0}));
  wrongDict->insert("A", torch::tensor({5.0}));
  wrongDict->insert("D", torch::tensor({5.0}));
  ASSERT_THROWS_WITH(dict->update(*wrongDict), "Parameter 'D' is not defined");
  // The rejected update must not have inserted the unknown key.
  ASSERT_EQ(dict->size(), 3);
  ASSERT_FALSE(dict->contains("D"));
  rightDict->insert("A", torch::tensor({5.0}));
  dict->update(*rightDict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::eq(dict["A"], torch::tensor({5.0})).item<bool>());
  // Keys absent from rightDict must keep their original values.
  ASSERT_TRUE(torch::eq(dict["B"], torch::tensor({2.0})).item<bool>());
  ASSERT_TRUE(torch::eq(dict["C"], torch::tensor({3.0})).item<bool>());
}

// keys() lists the keys in the same order the entries were given to the
// constructor; the comparison is an order-sensitive vector equality.
TEST_F(ParameterDictTest, Keys) {
  const std::vector<std::string> expected_keys{"a", "b", "c"};

  torch::OrderedDict<std::string, torch::Tensor> entries = {
      {"a", torch::tensor({1.0})},
      {"b", torch::tensor({2.0})},
      {"c", torch::tensor({1.0, 2.0})}};
  auto pd = torch::nn::ParameterDict(entries);

  ASSERT_EQ(pd->keys(), expected_keys);
}

// values() returns the stored tensors in constructor order, each element-wise
// equal to the tensor supplied for that position.
TEST_F(ParameterDictTest, Values) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::OrderedDict<std::string, torch::Tensor> params = {
      {"a", ta}, {"b", tb}, {"c", tc}};
  auto dict = torch::nn::ParameterDict(params);
  std::vector<torch::Tensor> values = dict->values();
  std::vector<torch::Tensor> true_values{ta, tb, tc};
  // The loop below is bounded by values.size(). Without this check an empty or
  // truncated values() would run it zero (or too few) times and pass vacuously.
  ASSERT_EQ(values.size(), true_values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(values[i], true_values[i])).item<bool>());
  }
}

// get(key) returns the tensor stored under key. This mirrors
// ConstructFromTensor but reads through get() instead of operator[], checking
// values for all three keys and the requires_grad flag of "A" and "B".
TEST_F(ParameterDictTest, Get) {
  ParameterDict dict;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  dict->insert("A", ta);
  dict->insert("B", tb);
  dict->insert("C", tc);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict->get("A"), ta)).item<bool>());
  ASSERT_TRUE(dict->get("A").requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict->get("B"), tb)).item<bool>());
  ASSERT_FALSE(dict->get("B").requires_grad());
  // "C" was inserted but previously never fetched through get().
  ASSERT_TRUE(torch::all(torch::eq(dict->get("C"), tc)).item<bool>());
}

// Formatting a ParameterDict with c10::str prints one line per entry, in
// constructor order, as "(key): Parameter containing: [<dtype> of size <shape>]",
// wrapped in "torch::nn::ParameterDict(" ... ")".
TEST_F(ParameterDictTest, PrettyPrintParameterDict) {
  const std::string expected =
      "torch::nn::ParameterDict(\n"
      "(a): Parameter containing: [Float of size [1]]\n"
      "(b): Parameter containing: [Float of size [2]]\n"
      "(c): Parameter containing: [Float of size [2, 1]]\n"
      "(d): Parameter containing: [Float of size [2, 2]]\n"
      ")";

  torch::OrderedDict<std::string, torch::Tensor> entries = {
      {"a", torch::tensor({1.0})},
      {"b", torch::tensor({2.0, 1.0})},
      {"c", torch::tensor({{3.0}, {2.1}})},
      {"d", torch::tensor({{3.0, 1.3}, {1.2, 2.1}})}};
  auto pd = torch::nn::ParameterDict(entries);

  ASSERT_EQ(c10::str(pd), expected);
}