xref: /aosp_15_r20/external/pytorch/test/cpp/api/parameterdict.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/torch.h>
3 #include <algorithm>
4 #include <memory>
5 #include <vector>
6 
7 #include <test/cpp/api/support.h>
8 
9 using namespace torch::nn;
10 using namespace torch::test;
11 
12 struct ParameterDictTest : torch::test::SeedingFixture {};
13 
// Builds a ParameterDict by inserting tensors one by one and verifies that
// every entry keeps both its values and its requires_grad flag.
TEST_F(ParameterDictTest, ConstructFromTensor) {
  ParameterDict dict;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  // No requires_grad option given: exercises the factory default (false).
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  ASSERT_FALSE(tc.requires_grad());
  dict->insert("A", ta);
  dict->insert("B", tb);
  dict->insert("C", tc);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
  ASSERT_TRUE(dict["A"].requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
  ASSERT_FALSE(dict["B"].requires_grad());
  // "C" was previously inserted but never checked.
  ASSERT_TRUE(torch::all(torch::eq(dict["C"], tc)).item<bool>());
  ASSERT_FALSE(dict["C"].requires_grad());
}

// Builds a ParameterDict from a prepared OrderedDict and verifies that every
// entry keeps both its values and its requires_grad flag.
TEST_F(ParameterDictTest, ConstructFromOrderedDict) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  // No requires_grad option given: exercises the factory default (false).
  torch::Tensor tc = torch::randn({1, 2});
  torch::OrderedDict<std::string, torch::Tensor> params = {
      {"A", ta}, {"B", tb}, {"C", tc}};
  auto dict = torch::nn::ParameterDict(params);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
  ASSERT_TRUE(dict["A"].requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
  ASSERT_FALSE(dict["B"].requires_grad());
  // "C" was previously part of the input but never checked.
  ASSERT_TRUE(torch::all(torch::eq(dict["C"], tc)).item<bool>());
  ASSERT_FALSE(dict["C"].requires_grad());
}

// contains() reports true only for keys that were actually inserted.
TEST_F(ParameterDictTest, InsertAndContains) {
  ParameterDict params;
  params->insert("A", torch::tensor({1.0}));
  ASSERT_EQ(params->size(), 1);
  ASSERT_TRUE(params->contains("A"));
  // A key that was never inserted must not be reported as present.
  ASSERT_FALSE(params->contains("C"));
}

// clear() drops every stored parameter, leaving the dict with size 0.
TEST_F(ParameterDictTest, InsertAndClear) {
  ParameterDict params;
  params->insert("A", torch::tensor({1.0}));
  ASSERT_EQ(params->size(), 1);

  params->clear();
  ASSERT_EQ(params->size(), 0);
}

// pop() on a missing key throws; pop() on a present key removes it and
// returns the stored tensor.
TEST_F(ParameterDictTest, InsertAndPop) {
  ParameterDict params;
  params->insert("A", torch::tensor({1.0}));
  ASSERT_EQ(params->size(), 1);

  ASSERT_THROWS_WITH(params->pop("B"), "Parameter 'B' is not defined");

  const torch::Tensor popped = params->pop("A");
  ASSERT_EQ(params->size(), 0);
  const torch::Tensor expected = torch::tensor({1.0});
  ASSERT_TRUE(torch::eq(popped, expected).item<bool>());
}

// update() overwrites existing keys from another ParameterDict and throws when
// the other dict contains a key this dict does not define.
TEST_F(ParameterDictTest, SimpleUpdate) {
  ParameterDict dict;
  ParameterDict wrongDict;
  ParameterDict rightDict;
  dict->insert("A", torch::tensor({1.0}));
  dict->insert("B", torch::tensor({2.0}));
  dict->insert("C", torch::tensor({3.0}));
  wrongDict->insert("A", torch::tensor({5.0}));
  wrongDict->insert("D", torch::tensor({5.0}));
  ASSERT_THROWS_WITH(dict->update(*wrongDict), "Parameter 'D' is not defined");
  // The rejected key must not have been added by the failed update.
  ASSERT_FALSE(dict->contains("D"));
  rightDict->insert("A", torch::tensor({5.0}));
  dict->update(*rightDict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::eq(dict["A"], torch::tensor({5.0})).item<bool>());
  // Keys absent from both update sources must keep their original values.
  ASSERT_TRUE(torch::eq(dict["B"], torch::tensor({2.0})).item<bool>());
  ASSERT_TRUE(torch::eq(dict["C"], torch::tensor({3.0})).item<bool>());
}

// keys() lists the parameter names in insertion order.
TEST_F(ParameterDictTest, Keys) {
  torch::OrderedDict<std::string, torch::Tensor> source = {
      {"a", torch::tensor({1.0})},
      {"b", torch::tensor({2.0})},
      {"c", torch::tensor({1.0, 2.0})}};
  auto dict = torch::nn::ParameterDict(source);

  const std::vector<std::string> expected_keys{"a", "b", "c"};
  const std::vector<std::string> actual_keys = dict->keys();
  ASSERT_EQ(actual_keys, expected_keys);
}

// values() returns the stored tensors in insertion order.
TEST_F(ParameterDictTest, Values) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::OrderedDict<std::string, torch::Tensor> params = {
      {"a", ta}, {"b", tb}, {"c", tc}};
  auto dict = torch::nn::ParameterDict(params);
  std::vector<torch::Tensor> values = dict->values();
  std::vector<torch::Tensor> true_values{ta, tb, tc};
  // Without this check, a short or empty result would pass the loop below
  // vacuously.
  ASSERT_EQ(values.size(), true_values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(values[i], true_values[i])).item<bool>());
  }
}

// get() returns the stored tensor for each key, preserving values and the
// requires_grad flag.
TEST_F(ParameterDictTest, Get) {
  ParameterDict dict;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  // No requires_grad option given: exercises the factory default (false).
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  ASSERT_FALSE(tc.requires_grad());
  dict->insert("A", ta);
  dict->insert("B", tb);
  dict->insert("C", tc);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(torch::all(torch::eq(dict->get("A"), ta)).item<bool>());
  ASSERT_TRUE(dict->get("A").requires_grad());
  ASSERT_TRUE(torch::all(torch::eq(dict->get("B"), tb)).item<bool>());
  ASSERT_FALSE(dict->get("B").requires_grad());
  // "C" was previously inserted but never retrieved.
  ASSERT_TRUE(torch::all(torch::eq(dict->get("C"), tc)).item<bool>());
  ASSERT_FALSE(dict->get("C").requires_grad());
}

// Streaming a ParameterDict prints one "(key): Parameter containing: ..." line
// per entry, listing dtype and size, in insertion order.
TEST_F(ParameterDictTest, PrettyPrintParameterDict) {
  torch::OrderedDict<std::string, torch::Tensor> source = {
      {"a", torch::tensor({1.0})},
      {"b", torch::tensor({2.0, 1.0})},
      {"c", torch::tensor({{3.0}, {2.1}})},
      {"d", torch::tensor({{3.0, 1.3}, {1.2, 2.1}})}};
  auto dict = torch::nn::ParameterDict(source);

  const std::string expected =
      "torch::nn::ParameterDict(\n"
      "(a): Parameter containing: [Float of size [1]]\n"
      "(b): Parameter containing: [Float of size [2]]\n"
      "(c): Parameter containing: [Float of size [2, 1]]\n"
      "(d): Parameter containing: [Float of size [2, 2]]\n"
      ")";
  ASSERT_EQ(c10::str(dict), expected);
}