xref: /aosp_15_r20/external/pytorch/test/cpp/api/moduledict.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/torch.h>
3 #include <algorithm>
4 #include <memory>
5 #include <vector>
6 
7 #include <test/cpp/api/support.h>
8 
9 using namespace torch::nn;
10 using namespace torch::test;
11 
12 struct ModuleDictTest : torch::test::SeedingFixture {};
13 
// A ModuleDict can be constructed from a vector of (name, module) pairs
// and reports the number of entries it was given.
TEST_F(ModuleDictTest, ConstructsFromList) {
  // Minimal module type carrying an int payload, used only for counting.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
      {"module_3", std::make_shared<M>(3)}};
  ModuleDict dict(list);
  ASSERT_EQ(dict->size(), 3);
}
27 
// A ModuleDict can be constructed directly from a torch::OrderedDict of
// (name, module) entries and reports the correct size.
TEST_F(ModuleDictTest, ConstructsFromordereddict) {
  // Minimal module type carrying an int payload, used only for counting.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
      {"module_3", std::make_shared<M>(3)},
  };
  ModuleDict dict(ordereddict);
  ASSERT_EQ(dict->size(), 3);
}
42 
// Exercises the mutating API of ModuleDict: update() from a list, an
// OrderedDict, and another ModuleDict; pop() of existing and missing keys;
// clear(); and contains() membership checks along the way.
TEST_F(ModuleDictTest, UpdatePopClearContains) {
  // Minimal module type carrying an int payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleDict dict;
  ASSERT_TRUE(dict->empty());
  // Update by List
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
      {"module_1", std::make_shared<M>(1)}};
  dict->update(list1);
  ASSERT_EQ(dict->size(), 1);
  ASSERT_TRUE(dict->contains("module_1"));
  // Update by OrderedDict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_2", std::make_shared<M>(2)}};
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 2);
  ASSERT_TRUE(dict->contains("module_2"));
  // Update by another ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
      {"module_3", std::make_shared<M>(3)}};
  ModuleDict updatedict(list2);
  dict->update(*updatedict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(dict->contains("module_3"));
  // Pop an existing key shrinks the dict.
  dict->pop("module_1");
  ASSERT_EQ(dict->size(), 2);
  // Popping a missing key must throw with a descriptive message.
  ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");
  // Clear removes everything.
  dict->clear();
  ASSERT_EQ(dict->size(), 0);
}
79 
// update() with a key that already exists must replace the stored module
// (not insert a duplicate), regardless of whether the source is a list, an
// OrderedDict, or another ModuleDict.
TEST_F(ModuleDictTest, UpdateExist) {
  // Minimal module type carrying an int payload used to detect replacement.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)}};
  ModuleDict dict(list1);
  ASSERT_EQ(dict->at<M>("module_2").value, 2);
  // Update by list: "module_2" is replaced, "module_3" is added.
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
      {"module_2", std::make_shared<M>(0)},
      {"module_3", std::make_shared<M>(3)}};
  dict->update(list2);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_EQ(dict->at<M>("module_2").value, 0);
  // Update by ordereddict: "module_3" is replaced, "module_4" is added.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_3", std::make_shared<M>(0)},
      {"module_4", std::make_shared<M>(4)}};
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_3").value, 0);
  // Update by ModuleDict: both keys already exist, so size stays at 4.
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list3 = {
      {"module_4", std::make_shared<M>(0)},
      {"module_1", std::make_shared<M>(0)}};
  ModuleDict dict2(list3);
  dict->update(*dict2);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_1").value, 0);
  ASSERT_EQ(dict->at<M>("module_4").value, 0);
}
114 
// keys() returns the insertion-ordered key list; operator[] retrieves the
// stored module by key and throws for unknown keys.
// (The unused local `struct M` present in the original was removed.)
TEST_F(ModuleDictTest, Keys) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict);
  const auto& keys = dict->keys();
  std::vector<std::string> expected{"linear", "conv", "dropout"};
  ASSERT_EQ(keys, expected);
  // Lookup of a missing key must throw with a descriptive message.
  ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");

  // Each stored module keeps its concrete type.
  ASSERT_TRUE(dict["linear"]->as<Linear>());
  ASSERT_TRUE(dict["conv"]->as<Conv2d>());
  ASSERT_TRUE(dict["dropout"]->as<Dropout>());
}
136 
// values() returns the stored module pointers in insertion order, and
// iterating the dict yields the same (key, module) pairs as the source
// OrderedDict (pointer-identical modules, i.e. no copies were made).
TEST_F(ModuleDictTest, Values) {
  // Minimal module type carrying an int payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
  };
  ModuleDict dict(ordereddict);
  const auto& values = dict->values();
  const auto& expected = ordereddict.values();
  ASSERT_EQ(values, expected);
  ASSERT_TRUE(std::equal(
      dict->begin(),
      dict->end(),
      ordereddict.begin(),
      [](const auto& lhs, const auto& rhs) {
        return lhs.value().get() == rhs.value().get();
      }));
}
159 
// Smoke test: a ModuleDict can hold a mix of standard torch::nn modules
// without errors at construction time.
TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);
}
170 
// Two ModuleDicts built from the same OrderedDict share the underlying
// module objects (reference semantics): identical sizes and pointer-equal
// entries at each position.
TEST_F(ModuleDictTest, HasReferenceSemantics) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear1", Linear(2, 3).ptr()},
      {"linear2", Linear(3, 4).ptr()},
      {"linear3", Linear(4, 5).ptr()},
  };
  ModuleDict first(ordereddict);
  ModuleDict second(ordereddict);

  ASSERT_EQ(first->size(), second->size());
  ASSERT_TRUE(std::equal(
      first->begin(),
      first->end(),
      second->begin(),
      [](const auto& lhs, const auto& rhs) {
        return lhs.value().get() == rhs.value().get();
      }));
}
189 
iscloneable_helper(torch::Device device)190 void iscloneable_helper(torch::Device device) {
191   torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
192       {"linear", Linear(2, 3).ptr()},
193       {"relu", Functional(torch::relu).ptr()},
194       {"batch", BatchNorm1d(3).ptr()},
195   };
196   ModuleDict dict(ordereddict);
197   dict->to(device);
198   ModuleDict clone =
199       std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
200   ASSERT_EQ(dict->size(), clone->size());
201 
202   for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end();
203        ++it, ++it_c) {
204     // The key should be same
205     ASSERT_EQ(it->key(), it_c->key());
206     // The modules should be the same kind (type).
207     ASSERT_EQ(it->value()->name(), it_c->value()->name());
208     // But not pointer-equal (distinct objects).
209     ASSERT_NE(it->value(), it_c->value());
210   }
211 
212   // Verify that the clone is deep, i.e. parameters of modules are cloned too.
213   torch::NoGradGuard no_grad;
214 
215   auto params1 = dict->named_parameters();
216   auto params2 = clone->named_parameters();
217   ASSERT_EQ(params1.size(), params2.size());
218   for (auto& param : params1) {
219     ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
220     ASSERT_EQ(param->device(), params2[param.key()].device());
221     ASSERT_TRUE(param->allclose(params2[param.key()]));
222     param->add_(2);
223   }
224   for (auto& param : params1) {
225     ASSERT_FALSE(param->allclose(params2[param.key()]));
226   }
227 }
228 
// Deep-clone semantics on CPU.
TEST_F(ModuleDictTest, IsCloneable) {
  iscloneable_helper(torch::kCPU);
}
232 
// Deep-clone semantics on the first CUDA device (runs only when the test
// harness enables *_CUDA tests).
TEST_F(ModuleDictTest, IsCloneable_CUDA) {
  iscloneable_helper({torch::kCUDA, 0});
}
236 
// Entries inserted into a ModuleDict are registered as child submodules in
// insertion order, and replacing an existing key via update() keeps its
// original position while new keys are appended at the end.
TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"test", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict1);

  auto modules = dict->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout>());

  // Update Existing: "test" is replaced in place, "lstm" is appended.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
      {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}};
  dict->update(ordereddict2);

  modules = dict->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  // Keep Order
  ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
  ASSERT_TRUE(modules[3]->as<LSTM>());
}
262 
// clone(device) must place every parameter and buffer of the cloned dict
// on the requested CUDA device.
TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(2, 3).ptr()},
      {"relu", Functional(torch::relu).ptr()},
      {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(ordereddict);
  torch::Device device(torch::kCUDA, 0);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
  for (const auto& p : clone->parameters()) {
    ASSERT_EQ(p.device(), device);
  }
  for (const auto& b : clone->buffers()) {
    ASSERT_EQ(b.device(), device);
  }
}
280 
// Pretty-printing a ModuleDict lists each entry as "(key): <module repr>"
// inside a torch::nn::ModuleDict(...) wrapper, in insertion order.
// The expected string is compared verbatim against c10::str(dict).
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);

  ASSERT_EQ(
      c10::str(dict),
      "torch::nn::ModuleDict(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}
302 
// at<T>() with a mismatched module type must throw rather than perform an
// unchecked cast.
TEST_F(ModuleDictTest, InvalidAt) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()}};
  ModuleDict dict(ordereddict);
  ASSERT_THROWS_WITH(
      dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module");
}
310