xref: /aosp_15_r20/external/pytorch/test/cpp/api/moduledict.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
3*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
4*da0073e9SAndroid Build Coastguard Worker #include <memory>
5*da0073e9SAndroid Build Coastguard Worker #include <vector>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
10*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker struct ModuleDictTest : torch::test::SeedingFixture {};
13*da0073e9SAndroid Build Coastguard Worker 
// A ModuleDict constructed from a vector of (name, module) pairs holds
// exactly one entry per pair.
TEST_F(ModuleDictTest, ConstructsFromList) {
  // Minimal module that only records the integer it was built with.
  struct M : Module {
    explicit M(int tag) : value(tag) {}
    int value;
  };

  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list;
  for (int i = 1; i <= 3; ++i) {
    list.emplace_back("module_" + std::to_string(i), std::make_shared<M>(i));
  }

  ModuleDict dict(list);
  ASSERT_EQ(dict->size(), 3);
}
27*da0073e9SAndroid Build Coastguard Worker 
// A ModuleDict constructed from a torch::OrderedDict holds exactly one entry
// per OrderedDict item.
TEST_F(ModuleDictTest, ConstructsFromordereddict) {
  // Minimal module that only records the integer it was built with.
  struct M : Module {
    explicit M(int tag) : value(tag) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict;
  for (const int tag : {1, 2, 3}) {
    ordereddict.insert(
        "module_" + std::to_string(tag), std::make_shared<M>(tag));
  }

  ModuleDict dict(ordereddict);
  ASSERT_EQ(dict->size(), 3);
}
42*da0073e9SAndroid Build Coastguard Worker 
// Exercises ModuleDict::update() from a vector of pairs, from an OrderedDict
// and from another ModuleDict, then pop() (for both a present and a missing
// key) and clear().
TEST_F(ModuleDictTest, UpdatePopClearContains) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleDict dict;
  ASSERT_TRUE(dict->empty());
  // Update by List
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
      {"module_1", std::make_shared<M>(1)}};
  dict->update(list1);
  ASSERT_EQ(dict->size(), 1);
  ASSERT_TRUE(dict->contains("module_1"));
  // Update by OrderedDict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_2", std::make_shared<M>(2)}};
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 2);
  ASSERT_TRUE(dict->contains("module_2"));
  // Update by another ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
      {"module_3", std::make_shared<M>(3)}};
  ModuleDict updatedict(list2);
  dict->update(*updatedict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(dict->contains("module_3"));
  // Pop: a size check alone would not catch removal of the wrong entry, so
  // also verify that exactly "module_1" disappeared.
  dict->pop("module_1");
  ASSERT_EQ(dict->size(), 2);
  ASSERT_FALSE(dict->contains("module_1"));
  ASSERT_TRUE(dict->contains("module_2"));
  ASSERT_TRUE(dict->contains("module_3"));
  // Pop unexist
  ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");
  // A failed pop must not have removed anything.
  ASSERT_EQ(dict->size(), 2);
  // Clear
  dict->clear();
  ASSERT_EQ(dict->size(), 0);
  ASSERT_TRUE(dict->empty());
}
79*da0073e9SAndroid Build Coastguard Worker 
// update() with keys that already exist replaces the stored module (size is
// unchanged for those keys) while new keys are added. Checked for all three
// update sources: vector of pairs, OrderedDict and another ModuleDict.
TEST_F(ModuleDictTest, UpdateExist) {
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };
  using Entries = std::vector<std::pair<std::string, std::shared_ptr<Module>>>;

  Entries initial = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)}};
  ModuleDict dict(initial);

  // Integer tag of the M stored under `key`.
  const auto tag_of = [&dict](const std::string& key) {
    return dict->at<M>(key).value;
  };

  ASSERT_EQ(tag_of("module_2"), 2);

  // From a vector: overwrites module_2, adds module_3.
  Entries from_vector = {
      {"module_2", std::make_shared<M>(0)},
      {"module_3", std::make_shared<M>(3)}};
  dict->update(from_vector);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_EQ(tag_of("module_2"), 0);

  // From an OrderedDict: overwrites module_3, adds module_4.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> from_ordered = {
      {"module_3", std::make_shared<M>(0)},
      {"module_4", std::make_shared<M>(4)}};
  dict->update(from_ordered);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(tag_of("module_3"), 0);

  // From another ModuleDict: overwrites module_4 and module_1, adds nothing.
  Entries for_other_dict = {
      {"module_4", std::make_shared<M>(0)},
      {"module_1", std::make_shared<M>(0)}};
  ModuleDict other(for_other_dict);
  dict->update(*other);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(tag_of("module_1"), 0);
  ASSERT_EQ(tag_of("module_4"), 0);
}
114*da0073e9SAndroid Build Coastguard Worker 
// keys() returns names in insertion order, operator[] on a missing key
// throws, and operator[] on present keys yields modules of the stored types.
// (The unused local `struct M` from the original test was removed.)
TEST_F(ModuleDictTest, Keys) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict);
  const auto& keys = dict->keys();
  std::vector<std::string> expected{"linear", "conv", "dropout"};
  ASSERT_EQ(keys, expected);
  ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");

  ASSERT_TRUE(dict["linear"]->as<Linear>());
  ASSERT_TRUE(dict["conv"]->as<Conv2d>());
  ASSERT_TRUE(dict["dropout"]->as<Dropout>());
}
136*da0073e9SAndroid Build Coastguard Worker 
// values() and iteration over a ModuleDict expose the very same module
// objects (pointer-equal) that it was constructed from, in the same order.
TEST_F(ModuleDictTest, Values) {
  struct M : Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
  };
  ModuleDict dict(ordereddict);

  // shared_ptr equality compares the managed pointers.
  ASSERT_EQ(dict->values(), ordereddict.values());

  const auto points_to_same_module = [](const auto& dict_item,
                                        const auto& source_item) {
    return dict_item.value().get() == source_item.value().get();
  };
  ASSERT_TRUE(std::equal(
      dict->begin(), dict->end(), ordereddict.begin(), points_to_same_module));
}
159*da0073e9SAndroid Build Coastguard Worker 
// A ModuleDict can hold a mix of standard torch::nn modules. Beyond
// constructing without throwing, check that every entry is kept and that the
// insertion order is preserved (the original test asserted nothing).
TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);

  ASSERT_EQ(dict->size(), ordereddict.size());
  ASSERT_EQ(dict->keys(), ordereddict.keys());
}
170*da0073e9SAndroid Build Coastguard Worker 
// Two ModuleDicts built from the same OrderedDict share the module objects
// rather than copying them: corresponding entries are pointer-equal.
TEST_F(ModuleDictTest, HasReferenceSemantics) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> shared_entries = {
      {"linear1", Linear(2, 3).ptr()},
      {"linear2", Linear(3, 4).ptr()},
      {"linear3", Linear(4, 5).ptr()},
  };
  ModuleDict first(shared_entries);
  ModuleDict second(shared_entries);

  const auto same_pointee = [](const auto& a, const auto& b) {
    return a.value().get() == b.value().get();
  };

  ASSERT_EQ(first->size(), second->size());
  ASSERT_TRUE(
      std::equal(first->begin(), first->end(), second->begin(), same_pointee));
}
189*da0073e9SAndroid Build Coastguard Worker 
iscloneable_helper(torch::Device device)190*da0073e9SAndroid Build Coastguard Worker void iscloneable_helper(torch::Device device) {
191*da0073e9SAndroid Build Coastguard Worker   torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
192*da0073e9SAndroid Build Coastguard Worker       {"linear", Linear(2, 3).ptr()},
193*da0073e9SAndroid Build Coastguard Worker       {"relu", Functional(torch::relu).ptr()},
194*da0073e9SAndroid Build Coastguard Worker       {"batch", BatchNorm1d(3).ptr()},
195*da0073e9SAndroid Build Coastguard Worker   };
196*da0073e9SAndroid Build Coastguard Worker   ModuleDict dict(ordereddict);
197*da0073e9SAndroid Build Coastguard Worker   dict->to(device);
198*da0073e9SAndroid Build Coastguard Worker   ModuleDict clone =
199*da0073e9SAndroid Build Coastguard Worker       std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
200*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(dict->size(), clone->size());
201*da0073e9SAndroid Build Coastguard Worker 
202*da0073e9SAndroid Build Coastguard Worker   for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end();
203*da0073e9SAndroid Build Coastguard Worker        ++it, ++it_c) {
204*da0073e9SAndroid Build Coastguard Worker     // The key should be same
205*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(it->key(), it_c->key());
206*da0073e9SAndroid Build Coastguard Worker     // The modules should be the same kind (type).
207*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(it->value()->name(), it_c->value()->name());
208*da0073e9SAndroid Build Coastguard Worker     // But not pointer-equal (distinct objects).
209*da0073e9SAndroid Build Coastguard Worker     ASSERT_NE(it->value(), it_c->value());
210*da0073e9SAndroid Build Coastguard Worker   }
211*da0073e9SAndroid Build Coastguard Worker 
212*da0073e9SAndroid Build Coastguard Worker   // Verify that the clone is deep, i.e. parameters of modules are cloned too.
213*da0073e9SAndroid Build Coastguard Worker   torch::NoGradGuard no_grad;
214*da0073e9SAndroid Build Coastguard Worker 
215*da0073e9SAndroid Build Coastguard Worker   auto params1 = dict->named_parameters();
216*da0073e9SAndroid Build Coastguard Worker   auto params2 = clone->named_parameters();
217*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(params1.size(), params2.size());
218*da0073e9SAndroid Build Coastguard Worker   for (auto& param : params1) {
219*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
220*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(param->device(), params2[param.key()].device());
221*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(param->allclose(params2[param.key()]));
222*da0073e9SAndroid Build Coastguard Worker     param->add_(2);
223*da0073e9SAndroid Build Coastguard Worker   }
224*da0073e9SAndroid Build Coastguard Worker   for (auto& param : params1) {
225*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(param->allclose(params2[param.key()]));
226*da0073e9SAndroid Build Coastguard Worker   }
227*da0073e9SAndroid Build Coastguard Worker }
228*da0073e9SAndroid Build Coastguard Worker 
// Deep-clone check for a ModuleDict living on the CPU.
TEST_F(ModuleDictTest, IsCloneable) {
  const torch::Device cpu_device(torch::kCPU);
  iscloneable_helper(cpu_device);
}
232*da0073e9SAndroid Build Coastguard Worker 
// Deep-clone check for a ModuleDict living on the first CUDA device.
TEST_F(ModuleDictTest, IsCloneable_CUDA) {
  const torch::Device cuda_device(torch::kCUDA, 0);
  iscloneable_helper(cuda_device);
}
236*da0073e9SAndroid Build Coastguard Worker 
// Entries of a ModuleDict are registered as its children, in insertion
// order. After update():
//   * an existing key ("test") is expected to keep its position (index 2)
//     with the new module;
//   * a new key ("lstm") is appended at the end.
TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"test", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict1);

  auto modules = dict->children();
  // Guard the indexing below: reading past the end of the vector would be UB
  // rather than a clean test failure.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout>());

  // Update Existing
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
      {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}};
  dict->update(ordereddict2);

  modules = dict->children();
  ASSERT_EQ(modules.size(), 4);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  // Keep Order
  ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
  ASSERT_TRUE(modules[3]->as<LSTM>());
}
262*da0073e9SAndroid Build Coastguard Worker 
// Cloning a CPU-resident ModuleDict directly onto CUDA device 0 places every
// parameter and every buffer of the clone on that device.
TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
  const torch::Device target(torch::kCUDA, 0);

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(2, 3).ptr()},
      {"relu", Functional(torch::relu).ptr()},
      {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(ordereddict);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(target));

  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device(), target);
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device(), target);
  }
}
280*da0073e9SAndroid Build Coastguard Worker 
// Streaming a ModuleDict prints one "(key): <module repr>" line per entry,
// indented by two spaces and wrapped in "torch::nn::ModuleDict(...)".
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);

  const std::string expected_repr =
      "torch::nn::ModuleDict(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(dict), expected_repr);
}
302*da0073e9SAndroid Build Coastguard Worker 
// at<T>() with a module type that does not match the stored module throws
// a cast error instead of returning a bogus reference.
TEST_F(ModuleDictTest, InvalidAt) {
  const std::string key = "linear";
  torch::OrderedDict<std::string, std::shared_ptr<Module>> single_entry = {
      {key, Linear(10, 3).ptr()}};
  ModuleDict dict(single_entry);

  // The entry is a Linear, so asking for a Dropout2d must fail.
  ASSERT_THROWS_WITH(
      dict->at<torch::nn::Dropout2dImpl>(key), "Unable to cast module");
}
310