xref: /aosp_15_r20/external/pytorch/test/cpp/api/module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
9*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker struct AGIUnit : torch::nn::Module {};
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace test {
14*da0073e9SAndroid Build Coastguard Worker struct AGIUnit : torch::nn::Module {};
15*da0073e9SAndroid Build Coastguard Worker struct AGIUnit2 : torch::nn::Module {
AGIUnit2test::AGIUnit216*da0073e9SAndroid Build Coastguard Worker   AGIUnit2() : torch::nn::Module("Foo") {}
17*da0073e9SAndroid Build Coastguard Worker };
18*da0073e9SAndroid Build Coastguard Worker } // namespace test
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker struct ModuleTest : torch::test::SeedingFixture {};
21*da0073e9SAndroid Build Coastguard Worker 
// A freshly constructed module reports training mode; eval() clears the flag
// and train() sets it again.
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear linear(3, 4);
  ASSERT_TRUE(linear->is_training());

  // Switch to evaluation mode...
  linear->eval();
  ASSERT_FALSE(linear->is_training());

  // ...and back to training mode.
  linear->train();
  ASSERT_TRUE(linear->is_training());
}
32*da0073e9SAndroid Build Coastguard Worker 
// Backpropagating through a Linear gives every parameter a defined, non-zero
// gradient; zero_grad() with its default argument then leaves every gradient
// undefined.
TEST_F(ModuleTest, ZeroGrad) {
  Linear linear(3, 4);
  // An (8, 3) all-ones batch fed *into* the layer (it is not a parameter).
  auto input = torch::ones({8, 3}, torch::requires_grad());
  linear(input).sum().backward();

  for (const auto& param : linear->parameters()) {
    const auto& g = param.grad();
    ASSERT_TRUE(g.defined());
    ASSERT_NE(g.sum().item<float>(), 0);
  }

  linear->zero_grad();

  for (const auto& param : linear->parameters()) {
    ASSERT_FALSE(param.grad().defined());
  }
}
51*da0073e9SAndroid Build Coastguard Worker 
// zero_grad(false) zeroes gradients that exist and leaves undefined ones
// undefined; zero_grad() with its default argument makes all of them
// undefined.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule()
        : x(register_parameter("x", torch::ones(5, torch::requires_grad()))),
          y(register_parameter("y", torch::ones(5, torch::requires_grad()))) {}
    torch::Tensor x, y;
  };

  TestModule m;
  (m.x * 2).sum().backward();

  // Only x took part in the graph, so only x received a gradient.
  ASSERT_TRUE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());

  // set_to_none = false: x's gradient stays defined (and becomes zero), y's
  // stays undefined.
  m.zero_grad(/*set_to_none=*/false);
  ASSERT_TRUE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());
  ASSERT_EQ(m.x.grad().sum().item<float>(), 0);

  // Default argument: every gradient is reset to undefined.
  m.zero_grad();
  ASSERT_FALSE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());
}
80*da0073e9SAndroid Build Coastguard Worker 
// Submodule names become components of dotted parameter paths (for example
// "l1.weight"), so register_module() rejects names containing '.' and empty
// names.
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  TestModel with_dot;
  ASSERT_THROWS_WITH(
      with_dot.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");

  TestModel with_empty;
  ASSERT_THROWS_WITH(
      with_empty.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}
90*da0073e9SAndroid Build Coastguard Worker 
// A second register_module() under an already-used name is rejected.
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const char* const kName = "linear";

  TestModel model;
  model.register_module(kName, torch::nn::Linear(3, 4)); // first: accepted
  ASSERT_THROWS_WITH(
      model.register_module(kName, torch::nn::Linear(3, 4)), // second: rejected
      "Submodule 'linear' already defined");
}
99*da0073e9SAndroid Build Coastguard Worker 
// replace_module() only swaps out an existing submodule; a bare Module has
// none, so any name is reported as undefined.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  torch::nn::Module empty;
  ASSERT_THROWS_WITH(
      empty.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
106*da0073e9SAndroid Build Coastguard Worker 
// replace_module() swaps a registered submodule for a new one. The registry
// (named_parameters / named_modules) then exposes the replacement, which is
// also what the call returns.
TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
    torch::nn::Linear l1{nullptr};
  };

  auto model = std::make_shared<TestModel>();
  // Swap the 3->4 layer for a 5->6 layer and keep the member in sync.
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));

  // The registered weight now has the new layer's 6 output rows...
  const auto weight = model->named_parameters()["l1.weight"];
  ASSERT_EQ(weight.size(0), 6);

  // ...and the registered submodule is the object the member points to.
  auto* registered = model->named_modules()["l1"]->as<Linear>();
  ASSERT_EQ(model->l1.get(), registered);
}
119*da0073e9SAndroid Build Coastguard Worker 
// unregister_module() throws for a name that was never registered, and for a
// registered one removes it so the module has no children left.
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : public torch::nn::Module {};
  TestModel model;

  // Nothing registered yet: unregistering must fail.
  ASSERT_THROWS_WITH(
      model.unregister_module("linear"),
      "No Module with name `linear` is registered");

  // Register then unregister: the child list is empty again.
  model.register_module("linear", torch::nn::Linear(3, 4));
  model.unregister_module("linear");
  const bool no_children = model.children().empty();
  ASSERT_TRUE(no_children);
}
130*da0073e9SAndroid Build Coastguard Worker 
// Parameter names follow the same rules as submodule names: no '.' and not
// empty.
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  TestModel with_dot;
  ASSERT_THROWS_WITH(
      with_dot.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");

  TestModel with_empty;
  ASSERT_THROWS_WITH(
      with_empty.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
140*da0073e9SAndroid Build Coastguard Worker 
// Registering a second parameter under an existing name is rejected. (The
// test name says "ModuleName" but it exercises parameter names.)
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const char* const kName = "p";

  TestModel model;
  model.register_parameter(kName, torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter(kName, torch::ones(5)),
      "Parameter 'p' already defined");
}
149*da0073e9SAndroid Build Coastguard Worker 
// An undefined tensor may be registered as a parameter, but parameters() does
// not report it. Asking for requires_grad=true on such a tensor (the default
// argument) emits exactly one warning; requires_grad=false emits none.
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};
  {
    WarningCapture warnings;

    TestModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);

    // Nothing is being ignored here, so the warning must not appear.
    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        0);
  }
  {
    WarningCapture warnings;

    TestModel model;
    // Default requires_grad=true cannot be honored for an undefined tensor.
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        1);
  }
}
172*da0073e9SAndroid Build Coastguard Worker 
// Buffer names are validated like parameter and submodule names: no '.' and
// not empty.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {};

  TestModel with_dot;
  ASSERT_THROWS_WITH(
      with_dot.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");

  TestModel with_empty;
  ASSERT_THROWS_WITH(
      with_empty.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
182*da0073e9SAndroid Build Coastguard Worker 
// Registering a second buffer under an existing name is rejected. (The test
// name says "ModuleName" but it exercises buffer names.)
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {};
  const char* const kName = "p";

  TestModel model;
  model.register_buffer(kName, torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer(kName, torch::ones(5)),
      "Buffer 'p' already defined");
}
190*da0073e9SAndroid Build Coastguard Worker 
// Module::name() reports the (namespace-qualified) type name of the module,
// unless an explicit name was passed to the Module constructor.
TEST_F(ModuleTest, CanGetName) {
  // EXPECT_* (non-fatal) rather than ASSERT_* because demangling may fail; a
  // mismatch is reported without aborting the remaining checks.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  // An explicit constructor-supplied name wins over the type name.
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
201*da0073e9SAndroid Build Coastguard Worker 
// Module::as<T>() returns a pointer to the module when it is a T (the holder
// type Linear and the impl type LinearImpl both work) and nullptr otherwise.
// The result is the same whether the module is reached through its holder, a
// shared_ptr<Module>, or a Module&.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear linear(3, 4);
  LinearImpl* const expected = linear.get();

  // Through the holder.
  ASSERT_EQ(linear->as<Linear>(), expected);
  ASSERT_EQ(linear->as<LinearImpl>(), expected);
  ASSERT_EQ(linear->as<Module>(), expected);
  ASSERT_EQ(linear->as<AGIUnit>(), nullptr);

  // Through a type-erased shared_ptr<Module>.
  std::shared_ptr<Module> erased = linear.ptr();
  ASSERT_EQ(erased->as<Linear>(), expected);
  ASSERT_EQ(erased->as<LinearImpl>(), expected);
  ASSERT_EQ(erased->as<Module>(), expected);
  ASSERT_EQ(erased->as<AGIUnit>(), nullptr);

  // Through a plain Module reference.
  Module& erased_ref = *erased;
  ASSERT_EQ(erased_ref.as<Linear>(), expected);
  ASSERT_EQ(erased_ref.as<LinearImpl>(), expected);
  ASSERT_EQ(erased_ref.as<Module>(), expected);
  ASSERT_EQ(erased_ref.as<AGIUnit>(), nullptr);
  // The downcast pointer gives access to LinearImpl members.
  if (auto* impl = erased_ref.as<Linear>()) {
    ASSERT_EQ(impl->weight.ndimension(), 2);
  }

  // An unrelated module type only casts to itself.
  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
229*da0073e9SAndroid Build Coastguard Worker 
test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::Device to_device,torch::Dtype to_dtype)230*da0073e9SAndroid Build Coastguard Worker void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
231*da0073e9SAndroid Build Coastguard Worker     torch::Device to_device,
232*da0073e9SAndroid Build Coastguard Worker     torch::Dtype to_dtype) {
233*da0073e9SAndroid Build Coastguard Worker   {
234*da0073e9SAndroid Build Coastguard Worker     // Case 1: Undefined tensors as parameters
235*da0073e9SAndroid Build Coastguard Worker     Linear module(LinearOptions(10, 20).bias(false));
236*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
237*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->bias.defined());
238*da0073e9SAndroid Build Coastguard Worker 
239*da0073e9SAndroid Build Coastguard Worker     module->to(to_device);
240*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
241*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(module->weight.device().type(), to_device.type());
242*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->bias.defined());
243*da0073e9SAndroid Build Coastguard Worker 
244*da0073e9SAndroid Build Coastguard Worker     module->to(to_dtype);
245*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
246*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(module->weight.dtype(), to_dtype);
247*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->bias.defined());
248*da0073e9SAndroid Build Coastguard Worker   }
249*da0073e9SAndroid Build Coastguard Worker   {
250*da0073e9SAndroid Build Coastguard Worker     // Case 2: Undefined tensors as buffers
251*da0073e9SAndroid Build Coastguard Worker     BatchNorm1d module(
252*da0073e9SAndroid Build Coastguard Worker         BatchNorm1dOptions(5).track_running_stats(false).affine(true));
253*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
254*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->running_mean.defined());
255*da0073e9SAndroid Build Coastguard Worker 
256*da0073e9SAndroid Build Coastguard Worker     module->to(to_device);
257*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
258*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(module->weight.device().type(), to_device.type());
259*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->running_mean.defined());
260*da0073e9SAndroid Build Coastguard Worker 
261*da0073e9SAndroid Build Coastguard Worker     module->to(to_dtype);
262*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(module->weight.defined());
263*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(module->weight.dtype(), to_dtype);
264*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(module->running_mean.defined());
265*da0073e9SAndroid Build Coastguard Worker   }
266*da0073e9SAndroid Build Coastguard Worker }
267*da0073e9SAndroid Build Coastguard Worker 
// CPU variant: convert to the CPU device, then to double precision.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  const torch::Device device(torch::kCPU);
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(device, torch::kDouble);
}
271*da0073e9SAndroid Build Coastguard Worker 
// CUDA variant: convert to a CUDA device, then to double precision.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  const torch::Device device(torch::kCUDA);
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(device, torch::kDouble);
}
276*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest,ParametersAndBuffersAccessorSkipsUndefinedTensor)277*da0073e9SAndroid Build Coastguard Worker TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
278*da0073e9SAndroid Build Coastguard Worker   {
279*da0073e9SAndroid Build Coastguard Worker     Linear module(LinearOptions(10, 20).bias(false));
280*da0073e9SAndroid Build Coastguard Worker 
281*da0073e9SAndroid Build Coastguard Worker     auto params = module->parameters();
282*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(params.size(), 1);
283*da0073e9SAndroid Build Coastguard Worker     auto named_params = module->named_parameters();
284*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(named_params.size(), 1);
285*da0073e9SAndroid Build Coastguard Worker 
286*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
287*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
288*da0073e9SAndroid Build Coastguard Worker   }
289*da0073e9SAndroid Build Coastguard Worker   {
290*da0073e9SAndroid Build Coastguard Worker     BatchNorm1d module(
291*da0073e9SAndroid Build Coastguard Worker         BatchNorm1dOptions(5).track_running_stats(false).affine(false));
292*da0073e9SAndroid Build Coastguard Worker 
293*da0073e9SAndroid Build Coastguard Worker     auto buffers = module->buffers();
294*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(buffers.size(), 0);
295*da0073e9SAndroid Build Coastguard Worker     auto named_buffers = module->named_buffers();
296*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(named_buffers.size(), 0);
297*da0073e9SAndroid Build Coastguard Worker   }
298*da0073e9SAndroid Build Coastguard Worker   {
299*da0073e9SAndroid Build Coastguard Worker     BatchNorm1d module(
300*da0073e9SAndroid Build Coastguard Worker         BatchNorm1dOptions(5).track_running_stats(true).affine(false));
301*da0073e9SAndroid Build Coastguard Worker 
302*da0073e9SAndroid Build Coastguard Worker     auto buffers = module->buffers();
303*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(buffers.size(), 3);
304*da0073e9SAndroid Build Coastguard Worker     auto named_buffers = module->named_buffers();
305*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(named_buffers.size(), 3);
306*da0073e9SAndroid Build Coastguard Worker 
307*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
308*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
309*da0073e9SAndroid Build Coastguard Worker         pointer_equal(named_buffers["running_mean"], module->running_mean));
310*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
311*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
312*da0073e9SAndroid Build Coastguard Worker         pointer_equal(named_buffers["running_var"], module->running_var));
313*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
314*da0073e9SAndroid Build Coastguard Worker         pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
315*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(pointer_equal(
316*da0073e9SAndroid Build Coastguard Worker         named_buffers["num_batches_tracked"], module->num_batches_tracked));
317*da0073e9SAndroid Build Coastguard Worker   }
318*da0073e9SAndroid Build Coastguard Worker }
319*da0073e9SAndroid Build Coastguard Worker 
// Moves a Linear from the CPU to cuda:0 and then cuda:1, back to the CPU, and
// finally to float64, checking every parameter after each step. Uses device
// index 1, so it needs at least two CUDA devices.
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear linear(128, 64);
  // Freshly constructed parameters live on the CPU as float32.
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(p.dtype(), torch::kFloat32);
  }

  // Visit cuda:0 then cuda:1, checking device type and index each time.
  for (const auto index : c10::irange(2)) {
    linear->to(torch::Device(torch::kCUDA, index));
    for (auto& p : linear->parameters()) {
      ASSERT_EQ(p.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(p.device().index(), index);
    }
  }

  // Back to the CPU.
  linear->to(torch::Device(torch::kCPU));
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.device().type(), torch::Device::Type::CPU);
  }

  // Dtype-only conversion.
  linear->to(torch::kFloat64);
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.dtype(), torch::kFloat64);
  }
}
351*da0073e9SAndroid Build Coastguard Worker 
// With requires_grad turned off on every parameter, the module is cast to
// integral dtypes. A single to(device, dtype) call is then expected to change
// both device and dtype. Targets cuda:1, so it needs at least two CUDA
// devices.
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear linear(128, 64);
  for (auto& p : linear->parameters()) {
    p.requires_grad_(false);
  }

  // Dtype-only conversion to int32.
  linear->to(torch::kInt32);
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.dtype(), torch::kInt32);
  }

  // Combined device + dtype conversion.
  const torch::Device target(torch::kCUDA, 1);
  linear->to(target, torch::kUInt8);
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(p.device().index(), 1);
  }
  for (auto& p : linear->parameters()) {
    ASSERT_EQ(p.dtype(), torch::kUInt8);
  }
}
374*da0073e9SAndroid Build Coastguard Worker 
// The base Module::clone() throws when a subclass does not override it.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable uncloneable;
  ASSERT_THROWS_WITH(
      uncloneable.clone(), "clone() has not been implemented");
}
380*da0073e9SAndroid Build Coastguard Worker 
// A module that overrides clone() must not reach the base-class
// "clone() has not been implemented" error.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    // The device argument is deliberately ignored, so it is left unnamed to
    // avoid an unused-parameter warning. Returning nullptr is enough here
    // because the test only checks that the call does not throw.
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& /*device*/ =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ module.clone(); });
}
393*da0073e9SAndroid Build Coastguard Worker 
394*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-exception-escape)
395*da0073e9SAndroid Build Coastguard Worker struct TestDistinctParametersModule
396*da0073e9SAndroid Build Coastguard Worker     : public Cloneable<TestDistinctParametersModule> {
TestDistinctParametersModuleTestDistinctParametersModule397*da0073e9SAndroid Build Coastguard Worker   TestDistinctParametersModule() {
398*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
399*da0073e9SAndroid Build Coastguard Worker     reset();
400*da0073e9SAndroid Build Coastguard Worker   }
resetTestDistinctParametersModule401*da0073e9SAndroid Build Coastguard Worker   void reset() override {
402*da0073e9SAndroid Build Coastguard Worker     l1 = register_module("l1", Linear(10, 3));
403*da0073e9SAndroid Build Coastguard Worker     l2 = register_module("l2", Linear(3, 5));
404*da0073e9SAndroid Build Coastguard Worker     l3 = register_module("l3", Linear(5, 100));
405*da0073e9SAndroid Build Coastguard Worker     buffer = register_buffer("buf", torch::ones({2, 2}));
406*da0073e9SAndroid Build Coastguard Worker   }
407*da0073e9SAndroid Build Coastguard Worker 
408*da0073e9SAndroid Build Coastguard Worker   Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
409*da0073e9SAndroid Build Coastguard Worker   torch::Tensor buffer;
410*da0073e9SAndroid Build Coastguard Worker };
411*da0073e9SAndroid Build Coastguard Worker 
testDistinctParameters(std::shared_ptr<Module> m1,std::shared_ptr<Module> m2)412*da0073e9SAndroid Build Coastguard Worker void testDistinctParameters(
413*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<Module> m1,
414*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<Module> m2) {
415*da0073e9SAndroid Build Coastguard Worker   auto params1 = m1->named_parameters();
416*da0073e9SAndroid Build Coastguard Worker   auto params2 = m2->named_parameters();
417*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(params1.size(), 6);
418*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(params2.size(), 6);
419*da0073e9SAndroid Build Coastguard Worker   for (auto& param : params1) {
420*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
421*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(param->allclose(params2[param.key()]));
422*da0073e9SAndroid Build Coastguard Worker     param->add_(2);
423*da0073e9SAndroid Build Coastguard Worker   }
424*da0073e9SAndroid Build Coastguard Worker   for (auto& param : params1) {
425*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(param->allclose(params2[param.key()]));
426*da0073e9SAndroid Build Coastguard Worker   }
427*da0073e9SAndroid Build Coastguard Worker 
428*da0073e9SAndroid Build Coastguard Worker   auto buffers1 = m1->named_buffers();
429*da0073e9SAndroid Build Coastguard Worker   auto buffers2 = m2->named_buffers();
430*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(buffers1.size(), 1);
431*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(buffers2.size(), 1);
432*da0073e9SAndroid Build Coastguard Worker   for (auto& buffer : buffers1) {
433*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
434*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
435*da0073e9SAndroid Build Coastguard Worker     buffer->add_(2);
436*da0073e9SAndroid Build Coastguard Worker   }
437*da0073e9SAndroid Build Coastguard Worker   for (auto& buffer : buffers1) {
438*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
439*da0073e9SAndroid Build Coastguard Worker   }
440*da0073e9SAndroid Build Coastguard Worker }
441*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  // A clone made without an explicit device must deep-copy every parameter
  // and buffer of the source module.
  const auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const auto copy = original->clone();
  testDistinctParameters(original, copy);
}
448*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  // Moves the module to cuda:0, clones it onto that same device explicitly,
  // and checks that the clone's tensors are independent copies.
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device cuda0(torch::kCUDA, 0);
  original->to(cuda0);
  auto copy = original->clone(cuda0);
  testDistinctParameters(original, copy);
}
457*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto source = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device first_gpu(torch::kCUDA, 0);
  const torch::Device second_gpu(torch::kCUDA, 1);
  source->to(first_gpu);
  auto target = source->clone(second_gpu);

  // Cloning onto another device leaves the source where it was and puts
  // every parameter of the clone on the requested device.
  for (const auto& tensor : source->parameters()) {
    ASSERT_EQ(tensor.device(), first_gpu);
  }
  for (const auto& tensor : target->parameters()) {
    ASSERT_EQ(tensor.device(), second_gpu);
  }

  // allclose() requires both operands on one device, so bring the clone back
  // to cuda:0 before comparing values.
  target->to(first_gpu);
  testDistinctParameters(source, target);
}
479*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // The module keeps its parameter in a member tensor as well. After cloning,
  // the clone's member must point at the clone's own registered parameter,
  // not at the original's.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  // clone() already yields std::shared_ptr<Module>, so downcast it directly.
  // Check the cast so a wrong dynamic type fails cleanly instead of
  // dereferencing null below.
  auto module2 = std::dynamic_pointer_cast<TestModule>(module->clone());
  ASSERT_TRUE(module2 != nullptr);

  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
511*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest,CloneCopiesTheValuesOfVariablesOfSubmodules)512*da0073e9SAndroid Build Coastguard Worker TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
513*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-exception-escape)
514*da0073e9SAndroid Build Coastguard Worker   struct TestModule : public Cloneable<TestModule> {
515*da0073e9SAndroid Build Coastguard Worker     TestModule() {
516*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
517*da0073e9SAndroid Build Coastguard Worker       reset();
518*da0073e9SAndroid Build Coastguard Worker     }
519*da0073e9SAndroid Build Coastguard Worker     void reset() override {
520*da0073e9SAndroid Build Coastguard Worker       weight = register_parameter("weight", torch::ones({4, 4}));
521*da0073e9SAndroid Build Coastguard Worker     }
522*da0073e9SAndroid Build Coastguard Worker 
523*da0073e9SAndroid Build Coastguard Worker     torch::Tensor weight;
524*da0073e9SAndroid Build Coastguard Worker     int value = 0;
525*da0073e9SAndroid Build Coastguard Worker   };
526*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-exception-escape)
527*da0073e9SAndroid Build Coastguard Worker   struct NestedModule : public Cloneable<NestedModule> {
528*da0073e9SAndroid Build Coastguard Worker     NestedModule() {
529*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
530*da0073e9SAndroid Build Coastguard Worker       reset();
531*da0073e9SAndroid Build Coastguard Worker     }
532*da0073e9SAndroid Build Coastguard Worker     void reset() override {
533*da0073e9SAndroid Build Coastguard Worker       module = register_module("module", std::make_shared<TestModule>());
534*da0073e9SAndroid Build Coastguard Worker     }
535*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<TestModule> module;
536*da0073e9SAndroid Build Coastguard Worker   };
537*da0073e9SAndroid Build Coastguard Worker 
538*da0073e9SAndroid Build Coastguard Worker   auto a = std::make_shared<NestedModule>();
539*da0073e9SAndroid Build Coastguard Worker   {
540*da0073e9SAndroid Build Coastguard Worker     torch::NoGradGuard no_grad;
541*da0073e9SAndroid Build Coastguard Worker     a->module->weight += 1;
542*da0073e9SAndroid Build Coastguard Worker     a->module->value = 123;
543*da0073e9SAndroid Build Coastguard Worker   }
544*da0073e9SAndroid Build Coastguard Worker 
545*da0073e9SAndroid Build Coastguard Worker   auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
546*da0073e9SAndroid Build Coastguard Worker 
547*da0073e9SAndroid Build Coastguard Worker   ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
548*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(pointer_equal(
549*da0073e9SAndroid Build Coastguard Worker       b->module->weight, b->module->named_parameters()["weight"]));
550*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
551*da0073e9SAndroid Build Coastguard Worker       b->module->named_parameters()["weight"].allclose(a->module->weight));
552*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
553*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(b->module->value, a->module->value);
554*da0073e9SAndroid Build Coastguard Worker }
555*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // clone() without an explicit device must keep every parameter and buffer
  // on the device the source module currently lives on.
  //
  // The file-scope TestDistinctParametersModule has exactly the layout this
  // test needs (three Linear submodules and a "buf" buffer). Reuse it instead
  // of redefining an identical local struct.
  TestDistinctParametersModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
589*da0073e9SAndroid Build Coastguard Worker 
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // clone(device) must place every parameter and buffer of the clone on the
  // requested device, even when the source module is still on the CPU.
  //
  // The file-scope TestDistinctParametersModule has exactly the layout this
  // test needs. Reuse it instead of redefining an identical local struct.
  TestDistinctParametersModule m;
  torch::Device device(torch::kCUDA, 1);
  // everything is on CPU here
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
623*da0073e9SAndroid Build Coastguard Worker 
624*da0073e9SAndroid Build Coastguard Worker struct ParameterTestModule : Module {
ParameterTestModuleParameterTestModule625*da0073e9SAndroid Build Coastguard Worker   ParameterTestModule() {
626*da0073e9SAndroid Build Coastguard Worker     a = register_parameter("a", torch::zeros({2, 2}));
627*da0073e9SAndroid Build Coastguard Worker     b = register_parameter("b", torch::ones({2, 2}));
628*da0073e9SAndroid Build Coastguard Worker     c = register_parameter("c", torch::ones({2, 2}) * 2);
629*da0073e9SAndroid Build Coastguard Worker   }
630*da0073e9SAndroid Build Coastguard Worker 
631*da0073e9SAndroid Build Coastguard Worker   torch::Tensor a, b, c;
632*da0073e9SAndroid Build Coastguard Worker };
633*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  // Both the flat and the named accessor must report all three parameters.
  ParameterTestModule module;
  const auto flat = module.parameters();
  ASSERT_EQ(flat.size(), 3);
  const auto named = module.named_parameters();
  ASSERT_EQ(named.size(), 3);
}
639*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  // Every name passed to register_parameter() must be a key of
  // named_parameters().
  ParameterTestModule module;
  auto parameters = module.named_parameters();
  for (const char* name : {"a", "b", "c"}) {
    ASSERT_TRUE(parameters.contains(name));
  }
}
647*da0073e9SAndroid Build Coastguard Worker 
648*da0073e9SAndroid Build Coastguard Worker struct BufferTestModule : Module {
BufferTestModuleBufferTestModule649*da0073e9SAndroid Build Coastguard Worker   BufferTestModule() {
650*da0073e9SAndroid Build Coastguard Worker     a = register_buffer("a", torch::zeros({2, 2}));
651*da0073e9SAndroid Build Coastguard Worker     b = register_buffer("b", torch::ones({2, 2}));
652*da0073e9SAndroid Build Coastguard Worker     c = register_buffer("c", torch::ones({2, 2}) * 2);
653*da0073e9SAndroid Build Coastguard Worker   }
654*da0073e9SAndroid Build Coastguard Worker 
655*da0073e9SAndroid Build Coastguard Worker   torch::Tensor a, b, c;
656*da0073e9SAndroid Build Coastguard Worker };
657*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  // Both the flat and the named accessor must report all three buffers.
  BufferTestModule module;
  const auto flat = module.buffers();
  ASSERT_EQ(flat.size(), 3);
  const auto named = module.named_buffers();
  ASSERT_EQ(named.size(), 3);
}
663*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  // Every name passed to register_buffer() must be a key of named_buffers().
  BufferTestModule module;
  auto buffers = module.named_buffers();
  for (const char* name : {"a", "b", "c"}) {
    ASSERT_TRUE(buffers.contains(name));
  }
}
671*da0073e9SAndroid Build Coastguard Worker 
672*da0073e9SAndroid Build Coastguard Worker struct AImpl : torch::nn::Module {
AImplAImpl673*da0073e9SAndroid Build Coastguard Worker   AImpl() : x_(123) {}
AImplAImpl674*da0073e9SAndroid Build Coastguard Worker   AImpl(int x) : x_(x) {}
675*da0073e9SAndroid Build Coastguard Worker   int x_;
676*da0073e9SAndroid Build Coastguard Worker };
677*da0073e9SAndroid Build Coastguard Worker TORCH_MODULE(A);
678*da0073e9SAndroid Build Coastguard Worker 
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  // A default-constructed holder must own an impl built by AImpl(), which
  // stores 123.
  A holder;
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 123);
}
687*da0073e9SAndroid Build Coastguard Worker 
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  // The argument given to the holder must reach AImpl(int).
  A holder(5);
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 5);
}
696*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  // A holder initialized from nullptr owns no impl. Accessing a member
  // through it must throw instead of dereferencing null.
  A holder = nullptr;
  ASSERT_FALSE(holder);
  ASSERT_TRUE(holder.is_empty());
  ASSERT_THROWS_WITH(holder->x_, "Accessing empty ModuleHolder");
}
703*da0073e9SAndroid Build Coastguard Worker 
704*da0073e9SAndroid Build Coastguard Worker struct TestModule : public torch::nn::Module {
TestModuleTestModule705*da0073e9SAndroid Build Coastguard Worker   TestModule(int64_t size) {
706*da0073e9SAndroid Build Coastguard Worker     p1 = register_parameter("p1", torch::randn({size}));
707*da0073e9SAndroid Build Coastguard Worker     p2 = register_parameter("p2", torch::randn({size}));
708*da0073e9SAndroid Build Coastguard Worker     b1 = register_buffer("b1", torch::randn({size}));
709*da0073e9SAndroid Build Coastguard Worker     b2 = register_buffer("b2", torch::randn({size}));
710*da0073e9SAndroid Build Coastguard Worker   }
711*da0073e9SAndroid Build Coastguard Worker 
forwardTestModule712*da0073e9SAndroid Build Coastguard Worker   torch::Tensor forward(torch::Tensor input) {
713*da0073e9SAndroid Build Coastguard Worker     return input;
714*da0073e9SAndroid Build Coastguard Worker   }
715*da0073e9SAndroid Build Coastguard Worker 
716*da0073e9SAndroid Build Coastguard Worker   torch::Tensor p1, p2, b1, b2;
717*da0073e9SAndroid Build Coastguard Worker };
718*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  // By default the container itself comes first, followed by its three
  // children in insertion order.
  const std::vector<ModulePtr> modules = model->modules();
  const std::vector<ModulePtr> expected{
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].get(), expected[idx].get());
  }
}
730*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  // With include_self=false only the three children remain, in order.
  const std::vector<ModulePtr> modules =
      model->modules(/*include_self=*/false);
  const std::vector<ModulePtr> expected{model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].get(), expected[idx].get());
  }
}
743*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  const torch::OrderedDict<std::string, ModulePtr> modules =
      model->named_modules();
  const std::vector<ModulePtr> expected{
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // The root carries the empty name; children are named by their position
    // in the Sequential ("0", "1", "2").
    const std::string expected_key =
        (idx == 0) ? std::string() : std::to_string(idx - 1);
    ASSERT_EQ(modules[idx].key(), expected_key);
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].value().get(), expected[idx].get());
  }
}
757*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  // Without the root, only the positional children "0", "1", "2" remain.
  const torch::OrderedDict<std::string, ModulePtr> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  const std::vector<ModulePtr> expected{model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(modules[idx].key(), std::to_string(idx));
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].value().get(), expected[idx].get());
  }
}
772*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  // children() lists only the direct submodules, in insertion order.
  const std::vector<ModulePtr> modules = model->children();
  const std::vector<ModulePtr> expected{model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].get(), expected[idx].get());
  }

  // Nothing is nested here, so the direct children are all the descendants.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
787*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));

  // Direct children only, keyed by their position in the Sequential.
  const torch::OrderedDict<std::string, ModulePtr> modules =
      model->named_children();
  const std::vector<ModulePtr> expected{model[0], model[1], model[2]};

  ASSERT_EQ(modules.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(modules[idx].key(), std::to_string(idx));
    // Identity check: the very same module objects, not copies.
    ASSERT_EQ(modules[idx].value().get(), expected[idx].get());
  }
}
801*da0073e9SAndroid Build Coastguard Worker 
// A standalone TestModule exposes exactly two parameters, p1 then p2; the
// returned tensors must point at the same data as the member tensors.
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  const std::vector<torch::Tensor> params = module.parameters();
  ASSERT_EQ(params.size(), 2);

  const torch::Tensor& first = params[0];
  const torch::Tensor& second = params[1];
  ASSERT_EQ(first.data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(second.data_ptr<float>(), module.p2.data_ptr<float>());
}
809*da0073e9SAndroid Build Coastguard Worker 
// named_parameters() on a standalone TestModule yields ("p1", p1) then
// ("p2", p2), each value pointing at the same data as the member tensor.
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  const torch::OrderedDict<std::string, torch::Tensor> named =
      module.named_parameters();
  ASSERT_EQ(named.size(), 2);

  const auto& first = named[0];
  ASSERT_EQ(first.key(), "p1");
  ASSERT_EQ(first.value().data_ptr<float>(), module.p1.data_ptr<float>());

  const auto& second = named[1];
  ASSERT_EQ(second.key(), "p2");
  ASSERT_EQ(second.value().data_ptr<float>(), module.p2.data_ptr<float>());
}
820*da0073e9SAndroid Build Coastguard Worker 
// A standalone TestModule exposes exactly two buffers, b1 then b2; the
// returned tensors must point at the same data as the member tensors.
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  const std::vector<torch::Tensor> bufs = module.buffers();
  ASSERT_EQ(bufs.size(), 2);

  const torch::Tensor& first = bufs[0];
  const torch::Tensor& second = bufs[1];
  ASSERT_EQ(first.data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(second.data_ptr<float>(), module.b2.data_ptr<float>());
}
828*da0073e9SAndroid Build Coastguard Worker 
// named_buffers() on a standalone TestModule yields ("b1", b1) then
// ("b2", b2), each value pointing at the same data as the member tensor.
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  const torch::OrderedDict<std::string, torch::Tensor> named =
      module.named_buffers();
  ASSERT_EQ(named.size(), 2);

  const auto& first = named[0];
  ASSERT_EQ(first.key(), "b1");
  ASSERT_EQ(first.value().data_ptr<float>(), module.b1.data_ptr<float>());

  const auto& second = named[1];
  ASSERT_EQ(second.key(), "b2");
  ASSERT_EQ(second.value().data_ptr<float>(), module.b2.data_ptr<float>());
}
839*da0073e9SAndroid Build Coastguard Worker 
840*da0073e9SAndroid Build Coastguard Worker struct TestContainer : torch::nn::Module {
TestContainerTestContainer841*da0073e9SAndroid Build Coastguard Worker   TestContainer(int64_t number, std::vector<TestContainer> modules = {})
842*da0073e9SAndroid Build Coastguard Worker       : tensor(torch::tensor(number)) {
843*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(modules.size())) {
844*da0073e9SAndroid Build Coastguard Worker       register_module(
845*da0073e9SAndroid Build Coastguard Worker           std::to_string(i),
846*da0073e9SAndroid Build Coastguard Worker           std::make_shared<TestContainer>(std::move(modules[i])));
847*da0073e9SAndroid Build Coastguard Worker     }
848*da0073e9SAndroid Build Coastguard Worker   }
849*da0073e9SAndroid Build Coastguard Worker   torch::Tensor tensor;
850*da0073e9SAndroid Build Coastguard Worker };
851*da0073e9SAndroid Build Coastguard Worker 
get_test_container_item(std::shared_ptr<torch::nn::Module> module)852*da0073e9SAndroid Build Coastguard Worker int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
853*da0073e9SAndroid Build Coastguard Worker   return std::dynamic_pointer_cast<TestContainer>(module)
854*da0073e9SAndroid Build Coastguard Worker       ->tensor.item<int64_t>();
855*da0073e9SAndroid Build Coastguard Worker }
856*da0073e9SAndroid Build Coastguard Worker 
make_deeply_nested_test_container()857*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
858*da0073e9SAndroid Build Coastguard Worker   return std::make_shared<TestContainer>(TestContainer(
859*da0073e9SAndroid Build Coastguard Worker       0,
860*da0073e9SAndroid Build Coastguard Worker       {TestContainer(1, {TestContainer(2), TestContainer(3)}),
861*da0073e9SAndroid Build Coastguard Worker        TestContainer(4),
862*da0073e9SAndroid Build Coastguard Worker        TestContainer(
863*da0073e9SAndroid Build Coastguard Worker            5,
864*da0073e9SAndroid Build Coastguard Worker            {TestContainer(6),
865*da0073e9SAndroid Build Coastguard Worker             TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
866*da0073e9SAndroid Build Coastguard Worker }
867*da0073e9SAndroid Build Coastguard Worker 
/// Expected (name, value) pairs for make_deeply_nested_test_container() when
/// traversed with the name prefix "test_prefix". Values are assigned 0..9 in
/// listing order, so each entry's value equals its position.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  std::vector<std::pair<std::string, int64_t>> pairs;
  int64_t value = 0;
  for (const char* name :
       {"test_prefix",
        "test_prefix.0",
        "test_prefix.0.0",
        "test_prefix.0.1",
        "test_prefix.1",
        "test_prefix.2",
        "test_prefix.2.0",
        "test_prefix.2.1",
        "test_prefix.2.1.0",
        "test_prefix.2.1.1"}) {
    pairs.emplace_back(name, value++);
  }
  return pairs;
}
882*da0073e9SAndroid Build Coastguard Worker 
// modules() on the nested tree returns all ten containers (root included),
// ordered so that the i-th entry holds the value i.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> all = model->modules();

  ASSERT_EQ(all.size(), 10);
  int64_t expected_value = 0;
  for (const auto& submodule : all) {
    ASSERT_EQ(get_test_container_item(submodule), expected_value);
    ++expected_value;
  }
}
892*da0073e9SAndroid Build Coastguard Worker 
// named_modules("test_prefix") must list every container under its dotted,
// prefixed path, matching the reference table entry for entry.
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> named =
      model->named_modules(/*name_prefix=*/"test_prefix");
  const auto reference = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(named.size(), reference.size());

  for (std::size_t slot = 0; slot < reference.size(); ++slot) {
    const auto& actual = named[slot];
    const auto& wanted = reference[slot];
    ASSERT_EQ(actual.key(), wanted.first);
    ASSERT_EQ(get_test_container_item(actual.value()), wanted.second);
  }
}
906*da0073e9SAndroid Build Coastguard Worker 
// children() on the nested tree returns only the root's direct children
// (values 1, 4, 5). Deeper descendants must not appear.
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> direct = model->children();

  ASSERT_EQ(direct.size(), 3);
  const int64_t expected_items[] = {1, 4, 5};
  for (const std::size_t slot : c10::irange(direct.size())) {
    ASSERT_EQ(get_test_container_item(direct[slot]), expected_items[slot]);
  }
}
916*da0073e9SAndroid Build Coastguard Worker 
// named_children() on the nested tree yields only the root's direct children,
// keyed by position, with container values 1, 4, 5.
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> direct =
      model->named_children();

  ASSERT_EQ(direct.size(), 3);

  const std::vector<std::pair<std::string, int64_t>> wanted = {
      {"0", 1}, {"1", 4}, {"2", 5}};
  for (const std::size_t slot : c10::irange(wanted.size())) {
    const auto& entry = direct[slot];
    ASSERT_EQ(get_test_container_item(entry.value()), wanted[slot].second);
    ASSERT_EQ(entry.key(), wanted[slot].first);
  }
}
933*da0073e9SAndroid Build Coastguard Worker 
// The test expects apply() to visit all ten modules in the order of their
// stored values 0..9, calling the callback exactly once per module.
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t visited = 0;
  model->apply([&visited](torch::nn::Module& module) {
    // Bump the counter before asserting so a failed check cannot stall it.
    const int64_t position = visited++;
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), position);
  });
  ASSERT_EQ(visited, 10);
}
942*da0073e9SAndroid Build Coastguard Worker 
// Same as ModuleApplyIteratesCorreclty, but through a pointer-to-const so the
// const apply() overload (callback takes const Module&) is exercised.
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t visited = 0;
  model->apply([&visited](const torch::nn::Module& module) {
    // Bump the counter before asserting so a failed check cannot stall it.
    const int64_t position = visited++;
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), position);
  });
  ASSERT_EQ(visited, 10);
}
952*da0073e9SAndroid Build Coastguard Worker 
// Named apply() must report each module with its dotted, prefixed path and
// value, in the order given by make_key_value_pairs_for_deeply_nested_container().
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  std::size_t index = 0;
  model->apply(
      // Capture `expected` by reference (as the sibling tests do); the lambda
      // only runs synchronously inside apply(), so it cannot dangle.
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        // Claim this visit's slot up front. An ASSERT_* failure returns from
        // the lambda, and a late increment would leave every later visit
        // compared against the same stale entry.
        const std::size_t position = index++;
        // Guard the lookup so an unexpectedly long traversal fails cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(position, expected.size());
        ASSERT_EQ(name, expected[position].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[position].second);
      },
      /*name_prefix=*/"test_prefix");
  // Every expected entry (10 of them) must have been visited exactly once.
  ASSERT_EQ(index, expected.size());
}
967*da0073e9SAndroid Build Coastguard Worker 
// Const variant of the named apply() test: through a pointer-to-const, the
// callback receives const Module& together with the dotted, prefixed path.
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  std::size_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name, const torch::nn::Module& module) {
        // Claim this visit's slot up front. An ASSERT_* failure returns from
        // the lambda, and a late increment would leave every later visit
        // compared against the same stale entry.
        const std::size_t position = index++;
        // Guard the lookup so an unexpectedly long traversal fails cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(position, expected.size());
        ASSERT_EQ(name, expected[position].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[position].second);
      },
      /*name_prefix=*/"test_prefix");
  // Every expected entry (10 of them) must have been visited exactly once.
  ASSERT_EQ(index, expected.size());
}
984*da0073e9SAndroid Build Coastguard Worker 
// apply() overload whose callback receives each module as a shared_ptr; the
// test expects the visits in the order of the stored values 0..9.
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t visited = 0;
  model->apply([&visited](const std::shared_ptr<torch::nn::Module>& module) {
    // Bump the counter before asserting so a failed check cannot stall it.
    const int64_t position = visited++;
    ASSERT_EQ(get_test_container_item(module), position);
  });
  ASSERT_EQ(visited, 10);
}
993*da0073e9SAndroid Build Coastguard Worker 
// Named apply() overload whose callback receives each module as a shared_ptr
// together with its dotted, prefixed path.
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  std::size_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        // Claim this visit's slot up front. An ASSERT_* failure returns from
        // the lambda, and a late increment would leave every later visit
        // compared against the same stale entry.
        const std::size_t position = index++;
        // Guard the lookup so an unexpectedly long traversal fails cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(position, expected.size());
        ASSERT_EQ(name, expected[position].first);
        ASSERT_EQ(get_test_container_item(module), expected[position].second);
      },
      /*name_prefix=*/"test_prefix");
  // Every expected entry (10 of them) must have been visited exactly once.
  ASSERT_EQ(index, expected.size());
}
1008*da0073e9SAndroid Build Coastguard Worker 
// modules() with the default include_self=true on a module that is not owned
// by a shared_ptr must throw; per the error text, it cannot hand out the
// top-level module as a shared_ptr. Excluding self, or owning the module
// through a shared_ptr, must not throw.
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  {
    // Stack-allocated and asking for self: expected to throw.
    TestModule on_stack(1);
    ASSERT_THROWS_WITH(
        on_stack.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr");
  }
  {
    // Stack-allocated but excluding self: fine.
    TestModule on_stack(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(on_stack.modules(/*include_self=*/false));
  }
  {
    // Owned by a shared_ptr: including self is fine.
    const auto owned = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(owned->modules());
  }
}
1028*da0073e9SAndroid Build Coastguard Worker 
1029*da0073e9SAndroid Build Coastguard Worker struct EmptyModule : torch::nn::Module {};
1030*da0073e9SAndroid Build Coastguard Worker 
// Streaming a module uses its pretty_print(): the default prints the type
// name, and an override controls the text completely.
TEST_F(ModuleTest, PrettyPrint) {
  // Local module with a custom pretty_print. It is named distinctly so it does
  // not shadow the file-level TestModule helper used by the other tests; the
  // printed text is fixed by pretty_print, not by the type name.
  struct PrintableModule : torch::nn::Module {
    PrintableModule(int x, float y) : x_(x), y_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  // Pass a float literal: the parameter is float, so `3.14` (a double) would
  // be narrowed implicitly.
  ASSERT_EQ(c10::str(PrintableModule(1, 3.14f)), "TestModule(x=1, y=3.14)");
}
1046*da0073e9SAndroid Build Coastguard Worker 
1047*da0073e9SAndroid Build Coastguard Worker struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
forwardModuleWithNonTensorForwardImpl1048*da0073e9SAndroid Build Coastguard Worker   int64_t forward(torch::Tensor x) {
1049*da0073e9SAndroid Build Coastguard Worker     return x.numel();
1050*da0073e9SAndroid Build Coastguard Worker   }
1051*da0073e9SAndroid Build Coastguard Worker };
1052*da0073e9SAndroid Build Coastguard Worker TORCH_MODULE(ModuleWithNonTensorForward);
1053*da0073e9SAndroid Build Coastguard Worker 
// Calling through the pimpl holder must return forward()'s integer result
// (the element count of a 123-element tensor).
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  ModuleWithNonTensorForward m;
  const int64_t element_count = m(torch::ones(123));
  ASSERT_EQ(element_count, 123);
}
1058