xref: /aosp_15_r20/external/pytorch/test/cpp/api/modulelist.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5 
6 #include <algorithm>
7 #include <memory>
8 #include <vector>
9 
10 #include <test/cpp/api/support.h>
11 
12 using namespace torch::nn;
13 using namespace torch::test;
14 
// Fixture shared by all ModuleList tests; inherits from
// torch::test::SeedingFixture (presumably seeds RNGs for deterministic
// runs -- see test/cpp/api/support.h).
struct ModuleListTest : torch::test::SeedingFixture {};
16 
// A ModuleList can be constructed directly from shared_ptrs to modules.
TEST_F(ModuleListTest, ConstructsFromSharedPointer) {
  // Minimal module carrying an int, used only to populate the list.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  auto first = std::make_shared<M>(1);
  auto second = std::make_shared<M>(2);
  auto third = std::make_shared<M>(3);
  ModuleList list(first, second, third);
  ASSERT_EQ(list->size(), 3);
}
26 
// A ModuleList can be constructed from concrete module values; each value
// is expected to be copied exactly once on the way in.
TEST_F(ModuleListTest, ConstructsFromConcreteType) {
  // Static so the local struct's copy constructor can bump it.
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    M(const M& other) : torch::nn::Module(other) {
      ++copy_count;
    }
    int value;
  };

  copy_count = 0;
  ModuleList list(M(1), M(2), M(3));
  ASSERT_EQ(list->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);
}
48 
// A ModuleList can be constructed from ModuleHolder wrappers.
TEST_F(ModuleListTest, ConstructsFromModuleHolder) {
  // Underlying implementation type.
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int value;
  };

  // Hand-rolled holder (what TORCH_MODULE(M) would normally generate).
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  M a(1);
  M b(2);
  M c(3);
  ModuleList list(a, b, c);
  ASSERT_EQ(list->size(), 3);
}
63 
// push_back grows the list by one, whichever overload is used.
TEST_F(ModuleListTest, PushBackAddsAnElement) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleList list;
  // Starts out empty.
  ASSERT_TRUE(list->is_empty());
  ASSERT_EQ(list->size(), 0);
  // push_back accepts a ModuleHolder...
  list->push_back(Linear(3, 4));
  ASSERT_EQ(list->size(), 1);
  // ...a shared_ptr to a module...
  list->push_back(std::make_shared<M>(1));
  ASSERT_EQ(list->size(), 2);
  // ...and a concrete module by value.
  list->push_back(M(2));
  ASSERT_EQ(list->size(), 3);
}
80 
// insert() places a module at an arbitrary index and renumbers submodules.
TEST_F(ModuleListTest, Insertion) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int value;
  };
  TORCH_MODULE(M);

  ModuleList list;
  list->push_back(MImpl(1));
  ASSERT_EQ(list->size(), 1);
  // Insert at the front, the middle, and the end, exercising both the
  // shared_ptr and the holder overloads.
  list->insert(0, std::make_shared<MImpl>(2));
  ASSERT_EQ(list->size(), 2);
  list->insert(1, M(3));
  ASSERT_EQ(list->size(), 3);
  list->insert(3, M(4));
  ASSERT_EQ(list->size(), 4);

  // Resulting order: 2, 3, 1, 4.
  ASSERT_EQ(list->at<MImpl>(0).value, 2);
  ASSERT_EQ(list->at<MImpl>(1).value, 3);
  ASSERT_EQ(list->at<MImpl>(2).value, 1);
  ASSERT_EQ(list->at<MImpl>(3).value, 4);

  // Submodule names must reflect the post-insertion indices.
  const std::unordered_map<size_t, size_t> expected = {
      {0, 2}, {1, 3}, {2, 1}, {3, 4}};
  for (const auto& entry : list->named_modules("", false)) {
    ASSERT_EQ(
        expected.at(std::stoul(entry.key())), entry.value()->as<M>()->value);
  }
}
106 
// at<T>(i) returns a typed reference to the i-th module and throws on a
// bad index.
TEST_F(ModuleListTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::shared_ptr<M>> owned;
  owned.push_back(std::make_shared<M>(1));
  owned.push_back(std::make_shared<M>(2));
  owned.push_back(std::make_shared<M>(3));

  ModuleList list;
  for (const auto& module : owned) {
    list->push_back(module);
  }
  ASSERT_EQ(list->size(), 3);

  // at<M>(i) must alias the exact object we inserted, not a copy.
  for (const auto i : c10::irange(owned.size())) {
    ASSERT_EQ(&list->at<M>(i), owned[i].get());
  }

  // Out-of-range indices throw.
  ASSERT_THROWS_WITH(list->at<M>(owned.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      list->at<M>(owned.size() + 1000000), "Index out of range");
}
131 
// ptr(i), operator[], and ptr<T>(i) all hand back the stored shared_ptr.
TEST_F(ModuleListTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::shared_ptr<M>> owned;
  owned.push_back(std::make_shared<M>(1));
  owned.push_back(std::make_shared<M>(2));
  owned.push_back(std::make_shared<M>(3));

  ModuleList list;
  for (const auto& module : owned) {
    list->push_back(module);
  }
  ASSERT_EQ(list->size(), 3);

  // All three accessors must alias the inserted modules.
  for (const auto i : c10::irange(owned.size())) {
    ASSERT_EQ(list->ptr(i).get(), owned[i].get());
    ASSERT_EQ(list[i].get(), owned[i].get());
    ASSERT_EQ(list->ptr<M>(i).get(), owned[i].get());
  }

  // Out-of-range indices throw.
  ASSERT_THROWS_WITH(list->ptr(owned.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(list->ptr(owned.size() + 1000000), "Index out of range");
}
157 
// Merely constructing a list holding one of each common built-in module
// must compile and not throw.
TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
  ModuleList list(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
}
167 
// extend() appends the contents of another list (or container) without
// mutating the source.
TEST_F(ModuleListTest, ExtendPushesModulesFromOtherModuleList) {
  struct A : torch::nn::Module {};
  struct B : torch::nn::Module {};
  struct C : torch::nn::Module {};
  struct D : torch::nn::Module {};
  ModuleList a(A{}, B{});
  ModuleList b(C{}, D{});

  // Extending `a` with `b` appends b's modules to a...
  a->extend(*b);
  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(a[0]->as<A>());
  ASSERT_TRUE(a[1]->as<B>());
  ASSERT_TRUE(a[2]->as<C>());
  ASSERT_TRUE(a[3]->as<D>());

  // ...while leaving `b` itself untouched.
  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());

  // extend() also accepts a plain container of shared_ptrs.
  std::vector<std::shared_ptr<A>> extras = {
      std::make_shared<A>(), std::make_shared<A>()};
  b->extend(extras);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());
  ASSERT_TRUE(b[2]->as<A>());
  ASSERT_TRUE(b[3]->as<A>());
}
197 
// Copying the ModuleList holder copies the handle, not the underlying list.
TEST_F(ModuleListTest, HasReferenceSemantics) {
  ModuleList first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  ModuleList second(first);

  // Both holders point at the very same implementation object...
  ASSERT_EQ(first.get(), second.get());
  ASSERT_EQ(first->size(), second->size());
  // ...so every element is pointer-identical between the two.
  ASSERT_TRUE(std::equal(
      first->begin(),
      first->end(),
      second->begin(),
      [](const std::shared_ptr<Module>& lhs,
         const std::shared_ptr<Module>& rhs) {
        return lhs.get() == rhs.get();
      }));
}
213 
// clone() produces a deep copy: same module types, distinct objects, and
// independently-owned parameters.
TEST_F(ModuleListTest, IsCloneable) {
  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
  ASSERT_EQ(list->size(), clone->size());

  // Use c10::irange for index iteration, consistent with the other tests in
  // this file (and avoiding a raw size_t counter loop).
  for (const auto i : c10::irange(list->size())) {
    // The modules should be the same kind (type).
    ASSERT_EQ(list[i]->name(), clone[i]->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(list[i], clone[i]);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.

  torch::NoGradGuard no_grad;

  auto params1 = list->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    // Same values, same device, but distinct storage.
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_EQ(param->device(), params2[param.key()].device());
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    // Mutating the original must not affect the clone.
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}
243 
// Elements of a ModuleList are registered as child submodules, in order,
// with their concrete types intact.
TEST_F(ModuleListTest, RegistersElementsAsSubmodules) {
  ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));

  const auto children = list->children();
  ASSERT_TRUE(children[0]->as<Linear>());
  ASSERT_TRUE(children[1]->as<Conv2d>());
  ASSERT_TRUE(children[2]->as<Dropout2d>());
}
252 
TEST_F(ModuleListTest, NestingIsPossible) {
  // A ModuleList may itself contain ModuleLists. Note: the second argument
  // below is a parenthesized comma expression, so the inner
  // `ModuleList(Dropout(), Dropout())` is constructed and discarded, and the
  // trailing `Dropout()` is what actually gets passed to the outer list.
  ModuleList list(
      (ModuleList(Dropout(), Dropout())),
      (ModuleList(Dropout(), Dropout()), Dropout()));
}
258 
// clone(device) moves every parameter and buffer of the clone to the
// requested device (requires CUDA; the _CUDA suffix gates the test).
TEST_F(ModuleListTest, CloneToDevice_CUDA) {
  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  const torch::Device device(torch::kCUDA, 0);
  ModuleList clone =
      std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device(), device);
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device(), device);
  }
}
271 
// Streaming a ModuleList yields an indexed, one-line-per-module summary.
TEST_F(ModuleListTest, PrettyPrintModuleList) {
  ModuleList list(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  // Expected output: each element prefixed by its index.
  const auto* expected =
      "torch::nn::ModuleList(\n"
      "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(list), expected);
}
291 
// Dereferencing the holder yields something range-for iterable; pretty
// printing each element must not throw.
TEST_F(ModuleListTest, RangeBasedForLoop) {
  torch::nn::ModuleList mlist(
      torch::nn::Linear(3, 4),
      torch::nn::BatchNorm1d(4),
      torch::nn::Dropout(0.5));

  std::stringstream stream;
  for (const auto& module : *mlist) {
    module->pretty_print(stream);
  }
}
303 
// at<T>() with a mismatched concrete type must throw rather than cast
// unsafely.
TEST_F(ModuleListTest, InvalidAt) {
  torch::nn::ModuleList list(torch::nn::Linear(1, 2));
  ASSERT_THROWS_WITH(
      list->at<torch::nn::Dropout2dImpl>(0), "Unable to cast module");
}
309