xref: /aosp_15_r20/external/pytorch/test/cpp/api/module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5 
6 #include <test/cpp/api/support.h>
7 
8 using namespace torch::nn;
9 using namespace torch::test;
10 
11 struct AGIUnit : torch::nn::Module {};
12 
13 namespace test {
14 struct AGIUnit : torch::nn::Module {};
15 struct AGIUnit2 : torch::nn::Module {
AGIUnit2test::AGIUnit216   AGIUnit2() : torch::nn::Module("Foo") {}
17 };
18 } // namespace test
19 
20 struct ModuleTest : torch::test::SeedingFixture {};
21 
// A new module starts in training mode; eval() and train() toggle the flag.
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear linear(3, 4);
  ASSERT_TRUE(linear->is_training());

  linear->eval();
  ASSERT_FALSE(linear->is_training());

  linear->train();
  ASSERT_TRUE(linear->is_training());
}
32 
// After a backward pass every parameter has a non-zero gradient; the default
// zero_grad() then leaves the gradients undefined.
TEST_F(ModuleTest, ZeroGrad) {
  Linear linear(3, 4);
  const auto input = torch::ones({8, 3}, torch::requires_grad());
  linear(input).sum().backward();

  for (const auto& param : linear->parameters()) {
    const auto& grad = param.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }

  linear->zero_grad();

  for (const auto& param : linear->parameters()) {
    ASSERT_FALSE(param.grad().defined());
  }
}
51 
// zero_grad(false) zeroes existing gradients and leaves missing ones missing.
// zero_grad() with the default argument clears every gradient.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  // Two parameters; only `x` takes part in the backward pass below.
  struct TwoParams : torch::nn::Module {
    TwoParams() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TwoParams m;
  (m.x * 2).sum().backward();

  ASSERT_TRUE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());

  m.zero_grad(/*set_to_none=*/false);

  ASSERT_TRUE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());
  ASSERT_EQ(m.x.grad().sum().item<float>(), 0);

  m.zero_grad();

  ASSERT_FALSE(m.x.grad().defined());
  ASSERT_FALSE(m.y.grad().defined());
}
80 
// Submodule names must be non-empty and must not contain a dot.
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct EmptyModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}

// Registering a second submodule under an existing name is rejected.
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct EmptyModel : public torch::nn::Module {};
  EmptyModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}

// replace_module() only works for names that were registered first.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  torch::nn::Module base;
  ASSERT_THROWS_WITH(
      base.replace_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
106 
// replace_module() swaps the registered child. Both the parameter listing and
// the named_modules() entry must reflect the new submodule.
TEST_F(ModuleTest, ReplaceModule) {
  struct Model : public torch::nn::Module {
    Model() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
    torch::nn::Linear l1{nullptr};
  };

  auto model = std::make_shared<Model>();
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));

  // The replacement Linear(5, 6) has a weight whose size(0) is 6.
  ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
119 
// unregister_module() throws for an unknown name and otherwise removes the
// child.
TEST_F(ModuleTest, UnregisterModule) {
  struct EmptyModel : public torch::nn::Module {};
  EmptyModel model;

  ASSERT_THROWS_WITH(
      model.unregister_module("linear"),
      "No Module with name `linear` is registered");

  model.register_module("linear", torch::nn::Linear(3, 4));
  model.unregister_module("linear");
  ASSERT_TRUE(model.children().empty());
}
130 
// Parameter names must be non-empty and must not contain a dot.
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct EmptyModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}

// Registering a second parameter under an existing name is rejected.
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct EmptyModel : public torch::nn::Module {};
  EmptyModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
149 
// An undefined tensor may be registered as a parameter, but it is not listed
// by parameters().
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct EmptyModel : public torch::nn::Module {};

  // Case 1: requires_grad=false is passed explicitly, so no warning is
  // involved.
  {
    EmptyModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);
  }

  // Case 2: the default requires_grad=true cannot apply to an undefined
  // tensor. It is ignored, and exactly one warning says so.
  {
    WarningCapture captured;

    EmptyModel model;
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    const auto warning_count = count_substr_occurrences(
        captured.str(), "Ignoring the `requires_grad=true` function parameter");
    ASSERT_EQ(warning_count, 1);
  }
}
172 
// Buffer names must be non-empty and must not contain a dot.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct EmptyModel : public torch::nn::Module {};
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      EmptyModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}

// Registering a second buffer under an existing name is rejected.
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct EmptyModel : public torch::nn::Module {};
  EmptyModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}
190 
// Without an explicit name, Module::name() is derived from the (demangled)
// type name and includes any enclosing namespace. A name passed to the Module
// constructor takes precedence.
TEST_F(ModuleTest, CanGetName) {
  // EXPECT_* (non-fatal) rather than ASSERT_* because demangling may fail on
  // some toolchains; the remaining checks should still run and report.
  AGIUnit agi;
  // Call name() twice to exercise its lazy initialization: the second call
  // must return the same value as the first.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
201 
// Module::as<T>() yields the module pointer when the dynamic type matches T
// (the holder type Linear and the impl type LinearImpl both work) and nullptr
// for an unrelated module type. This is checked through a holder, a
// type-erased shared_ptr and a plain reference.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear linear(3, 4);
  auto* const expected = linear.get();

  // Through the ModuleHolder.
  ASSERT_EQ(linear->as<Linear>(), expected);
  ASSERT_EQ(linear->as<LinearImpl>(), expected);
  ASSERT_EQ(linear->as<Module>(), expected);
  ASSERT_EQ(linear->as<AGIUnit>(), nullptr);

  // Through a type-erased std::shared_ptr<Module>.
  std::shared_ptr<Module> base_ptr = linear.ptr();
  ASSERT_EQ(base_ptr->as<Linear>(), expected);
  ASSERT_EQ(base_ptr->as<LinearImpl>(), expected);
  ASSERT_EQ(base_ptr->as<Module>(), expected);
  ASSERT_EQ(base_ptr->as<AGIUnit>(), nullptr);

  // Through a plain Module reference.
  Module& base_ref = *base_ptr;
  ASSERT_EQ(base_ref.as<Linear>(), expected);
  ASSERT_EQ(base_ref.as<LinearImpl>(), expected);
  ASSERT_EQ(base_ref.as<Module>(), expected);
  ASSERT_EQ(base_ref.as<AGIUnit>(), nullptr);
  if (auto* impl = base_ref.as<Linear>()) {
    ASSERT_EQ(impl->weight.ndimension(), 2);
  }

  // The reverse direction: an AGIUnit is not a Linear.
  AGIUnit unrelated;
  ASSERT_EQ(unrelated.as<Linear>(), nullptr);
  ASSERT_EQ(unrelated.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unrelated.as<AGIUnit>(), &unrelated);
}
229 
test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::Device to_device,torch::Dtype to_dtype)230 void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
231     torch::Device to_device,
232     torch::Dtype to_dtype) {
233   {
234     // Case 1: Undefined tensors as parameters
235     Linear module(LinearOptions(10, 20).bias(false));
236     ASSERT_TRUE(module->weight.defined());
237     ASSERT_FALSE(module->bias.defined());
238 
239     module->to(to_device);
240     ASSERT_TRUE(module->weight.defined());
241     ASSERT_EQ(module->weight.device().type(), to_device.type());
242     ASSERT_FALSE(module->bias.defined());
243 
244     module->to(to_dtype);
245     ASSERT_TRUE(module->weight.defined());
246     ASSERT_EQ(module->weight.dtype(), to_dtype);
247     ASSERT_FALSE(module->bias.defined());
248   }
249   {
250     // Case 2: Undefined tensors as buffers
251     BatchNorm1d module(
252         BatchNorm1dOptions(5).track_running_stats(false).affine(true));
253     ASSERT_TRUE(module->weight.defined());
254     ASSERT_FALSE(module->running_mean.defined());
255 
256     module->to(to_device);
257     ASSERT_TRUE(module->weight.defined());
258     ASSERT_EQ(module->weight.device().type(), to_device.type());
259     ASSERT_FALSE(module->running_mean.defined());
260 
261     module->to(to_dtype);
262     ASSERT_TRUE(module->weight.defined());
263     ASSERT_EQ(module->weight.dtype(), to_dtype);
264     ASSERT_FALSE(module->running_mean.defined());
265   }
266 }
267 
// CPU variant: converts to kDouble on the CPU.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  const torch::Device target_device = torch::kCPU;
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      target_device, torch::kDouble);
}

// CUDA variant: converts to kDouble on a CUDA device.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  const torch::Device target_device = torch::kCUDA;
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      target_device, torch::kDouble);
}
276 
// parameters()/named_parameters() and buffers()/named_buffers() omit
// undefined tensors. The defined entries they return alias the module's
// member tensors.
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // Linear without bias: only `weight` is listed.
    Linear linear(LinearOptions(10, 20).bias(false));

    auto param_list = linear->parameters();
    ASSERT_EQ(param_list.size(), 1);
    auto param_dict = linear->named_parameters();
    ASSERT_EQ(param_dict.size(), 1);

    ASSERT_TRUE(pointer_equal(param_list[0], param_dict["weight"]));
    ASSERT_TRUE(pointer_equal(param_dict["weight"], linear->weight));
  }
  {
    // No running stats: no buffers at all.
    BatchNorm1d norm(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto buffer_list = norm->buffers();
    ASSERT_EQ(buffer_list.size(), 0);
    auto buffer_dict = norm->named_buffers();
    ASSERT_EQ(buffer_dict.size(), 0);
  }
  {
    // With running stats: three buffers, listed in registration order.
    BatchNorm1d norm(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto buffer_list = norm->buffers();
    ASSERT_EQ(buffer_list.size(), 3);
    auto buffer_dict = norm->named_buffers();
    ASSERT_EQ(buffer_dict.size(), 3);

    ASSERT_TRUE(pointer_equal(buffer_list[0], buffer_dict["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(buffer_dict["running_mean"], norm->running_mean));
    ASSERT_TRUE(pointer_equal(buffer_list[1], buffer_dict["running_var"]));
    ASSERT_TRUE(pointer_equal(buffer_dict["running_var"], norm->running_var));
    ASSERT_TRUE(
        pointer_equal(buffer_list[2], buffer_dict["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        buffer_dict["num_batches_tracked"], norm->num_batches_tracked));
  }
}
319 
// Module::to() moves every parameter between CUDA devices and back to the CPU,
// and also changes the dtype.
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear linear(128, 64);

  // Freshly constructed: float32 on the CPU.
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(param.dtype(), torch::kFloat32);
  }

  // CUDA device 0, then CUDA device 1.
  linear->to({torch::kCUDA, 0});
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 0);
  }
  linear->to({torch::kCUDA, 1});
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }

  // Back to the CPU.
  linear->to(torch::Device(torch::kCPU));
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CPU);
  }

  // Dtype-only conversion.
  linear->to(torch::kFloat64);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kFloat64);
  }
}
351 
// With requires_grad disabled on every parameter, Module::to() can convert to
// integral dtypes, alone or together with a device change.
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear linear(128, 64);
  for (auto& param : linear->parameters()) {
    param.requires_grad_(false);
  }

  // Dtype-only conversion to int32.
  linear->to(torch::kInt32);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kInt32);
  }

  // Combined device + dtype conversion.
  linear->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kUInt8);
  }
}
374 
// A plain Module subclass inherits the base clone(), which throws.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct NotCloneable : Module {};
  NotCloneable m;
  ASSERT_THROWS_WITH(m.clone(), "clone() has not been implemented");
}

// Overriding clone() replaces the throwing default implementation.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct CustomClone : Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& /*device*/ =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  CustomClone m;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ m.clone(); });
}
393 
394 // NOLINTNEXTLINE(bugprone-exception-escape)
395 struct TestDistinctParametersModule
396     : public Cloneable<TestDistinctParametersModule> {
TestDistinctParametersModuleTestDistinctParametersModule397   TestDistinctParametersModule() {
398     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
399     reset();
400   }
resetTestDistinctParametersModule401   void reset() override {
402     l1 = register_module("l1", Linear(10, 3));
403     l2 = register_module("l2", Linear(3, 5));
404     l3 = register_module("l3", Linear(5, 100));
405     buffer = register_buffer("buf", torch::ones({2, 2}));
406   }
407 
408   Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
409   torch::Tensor buffer;
410 };
411 
testDistinctParameters(std::shared_ptr<Module> m1,std::shared_ptr<Module> m2)412 void testDistinctParameters(
413     std::shared_ptr<Module> m1,
414     std::shared_ptr<Module> m2) {
415   auto params1 = m1->named_parameters();
416   auto params2 = m2->named_parameters();
417   ASSERT_EQ(params1.size(), 6);
418   ASSERT_EQ(params2.size(), 6);
419   for (auto& param : params1) {
420     ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
421     ASSERT_TRUE(param->allclose(params2[param.key()]));
422     param->add_(2);
423   }
424   for (auto& param : params1) {
425     ASSERT_FALSE(param->allclose(params2[param.key()]));
426   }
427 
428   auto buffers1 = m1->named_buffers();
429   auto buffers2 = m2->named_buffers();
430   ASSERT_EQ(buffers1.size(), 1);
431   ASSERT_EQ(buffers2.size(), 1);
432   for (auto& buffer : buffers1) {
433     ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
434     ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
435     buffer->add_(2);
436   }
437   for (auto& buffer : buffers1) {
438     ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
439   }
440 }
441 
// clone() without a device argument deep-copies parameters and buffers.
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  auto copy = original->clone();
  testDistinctParameters(original, copy);
}
448 
// clone(device) onto the device the module already lives on still produces
// distinct storage.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device cuda0(torch::kCUDA, 0);
  original->to(cuda0);
  auto copy = original->clone(cuda0);
  testDistinctParameters(original, copy);
}
457 
// clone(other_device) places the copy on the requested device, leaves the
// original where it was, and produces distinct storage.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  const torch::Device source(torch::kCUDA, 0);
  const torch::Device target(torch::kCUDA, 1);
  original->to(source);
  auto copy = original->clone(target);

  for (auto& param : original->parameters()) {
    ASSERT_EQ(param.device(), source);
  }
  for (auto& param : copy->parameters()) {
    ASSERT_EQ(param.device(), target);
  }

  // allclose() needs both tensors on the same device, so bring the copy back
  // to the source device before comparing.
  copy->to(source);
  testDistinctParameters(original, copy);
}
479 
// After clone(), the copy's member tensor and its registered parameter are the
// same tensor. That tensor is distinct from, but equal in value to, the
// original's.
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    // In-place update of a leaf that requires grad needs grad mode off.
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  // clone() returns std::shared_ptr<Module>. Verify the downcast succeeded so
  // a failure shows up as a test failure rather than a null dereference.
  auto module2 = std::dynamic_pointer_cast<TestModule>(module->clone());
  ASSERT_NE(module2.get(), nullptr);

  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
511 
// Cloning a parent module deep-copies the tensors of its submodules and
// carries over their plain (non-tensor) member values.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    // Non-tensor state; checked below to be copied by clone().
    int value = 0;
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  // Guard the downcast so a failure is reported instead of crashing.
  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
  ASSERT_NE(b.get(), nullptr);

  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}
555 
// Cloning a module that already lives on a CUDA device, without passing a
// device to clone(), keeps every parameter and buffer on that device.
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // TestDistinctParametersModule (defined earlier in this file) has exactly
  // the three-Linear + "buf" buffer layout this test needs; reuse it instead
  // of redefining an identical local struct.
  TestDistinctParametersModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
589 
// clone(device) on a CPU module places every parameter and buffer of the copy
// on the requested CUDA device.
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // Reuse TestDistinctParametersModule (defined earlier in this file), which
  // is identical to the local struct this test previously redefined.
  TestDistinctParametersModule m;
  torch::Device device(torch::kCUDA, 1);
  // Everything is on the CPU here.
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
623 
624 struct ParameterTestModule : Module {
ParameterTestModuleParameterTestModule625   ParameterTestModule() {
626     a = register_parameter("a", torch::zeros({2, 2}));
627     b = register_parameter("b", torch::ones({2, 2}));
628     c = register_parameter("c", torch::ones({2, 2}) * 2);
629   }
630 
631   torch::Tensor a, b, c;
632 };
633 
// parameters() and named_parameters() both report the three registered
// parameters.
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule m;
  ASSERT_EQ(m.parameters().size(), 3);
  ASSERT_EQ(m.named_parameters().size(), 3);
}

// named_parameters() is keyed by the registration names.
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule m;
  auto by_name = m.named_parameters();
  ASSERT_TRUE(by_name.contains("a"));
  ASSERT_TRUE(by_name.contains("b"));
  ASSERT_TRUE(by_name.contains("c"));
}
647 
648 struct BufferTestModule : Module {
BufferTestModuleBufferTestModule649   BufferTestModule() {
650     a = register_buffer("a", torch::zeros({2, 2}));
651     b = register_buffer("b", torch::ones({2, 2}));
652     c = register_buffer("c", torch::ones({2, 2}) * 2);
653   }
654 
655   torch::Tensor a, b, c;
656 };
657 
// buffers() and named_buffers() both report the three registered buffers.
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule m;
  ASSERT_EQ(m.buffers().size(), 3);
  ASSERT_EQ(m.named_buffers().size(), 3);
}

// named_buffers() is keyed by the registration names.
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule m;
  auto by_name = m.named_buffers();
  ASSERT_TRUE(by_name.contains("a"));
  ASSERT_TRUE(by_name.contains("b"));
  ASSERT_TRUE(by_name.contains("c"));
}
671 
672 struct AImpl : torch::nn::Module {
AImplAImpl673   AImpl() : x_(123) {}
AImplAImpl674   AImpl(int x) : x_(x) {}
675   int x_;
676 };
677 TORCH_MODULE(A);
678 
// A default-constructed holder owns a default-constructed impl.
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A holder;
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 123);
}

// Constructor arguments to the holder are forwarded to the impl.
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A holder(5);
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 5);
}

// A holder built from nullptr is empty and throws on member access.
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A holder = nullptr;
  ASSERT_FALSE(holder);
  ASSERT_TRUE(holder.is_empty());
  ASSERT_THROWS_WITH(holder->x_, "Accessing empty ModuleHolder");
}
703 
704 struct TestModule : public torch::nn::Module {
TestModuleTestModule705   TestModule(int64_t size) {
706     p1 = register_parameter("p1", torch::randn({size}));
707     p2 = register_parameter("p2", torch::randn({size}));
708     b1 = register_buffer("b1", torch::randn({size}));
709     b2 = register_buffer("b2", torch::randn({size}));
710   }
711 
forwardTestModule712   torch::Tensor forward(torch::Tensor input) {
713     return input;
714   }
715 
716   torch::Tensor p1, p2, b1, b2;
717 };
718 
// modules() lists the container itself first, then its children in order.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  const std::vector<std::shared_ptr<torch::nn::Module>> actual =
      seq->modules();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq.ptr(), seq[0], seq[1], seq[2]};
  ASSERT_EQ(actual.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Identity, not value equality.
    ASSERT_EQ(actual[i].get(), expected[i].get());
  }
}

// modules(/*include_self=*/false) omits the container itself.
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  const std::vector<std::shared_ptr<torch::nn::Module>> actual =
      seq->modules(/*include_self=*/false);
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq[0], seq[1], seq[2]};
  ASSERT_EQ(actual.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Identity, not value equality.
    ASSERT_EQ(actual[i].get(), expected[i].get());
  }
}
743 
// named_modules() lists the container under the empty key, followed by the
// children keyed "0", "1", "2".
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> named =
      seq->named_modules();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq.ptr(), seq[0], seq[1], seq[2]};
  ASSERT_EQ(named.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    std::string expected_key;
    if (i != 0) {
      expected_key = std::to_string(i - 1);
    }
    ASSERT_EQ(named[i].key(), expected_key);
    // Identity, not value equality.
    ASSERT_EQ(named[i].value().get(), expected[i].get());
  }
}

// With include_self=false only the children ("0", "1", "2") are listed.
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> named =
      seq->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq[0], seq[1], seq[2]};
  ASSERT_EQ(named.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    ASSERT_EQ(named[i].key(), std::to_string(i));
    // Identity, not value equality.
    ASSERT_EQ(named[i].value().get(), expected[i].get());
  }
}
772 
// children() lists only the direct submodules, in order.
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> kids = seq->children();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq[0], seq[1], seq[2]};
  ASSERT_EQ(kids.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Identity, not value equality.
    ASSERT_EQ(kids[i].get(), expected[i].get());
  }

  // In a flat model the direct children are all the non-self modules.
  ASSERT_EQ(kids, seq->modules(/*include_self=*/false));
}

// named_children() keys the direct submodules "0", "1", "2".
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential seq(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> named =
      seq->named_children();
  const std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      seq[0], seq[1], seq[2]};
  ASSERT_EQ(named.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    ASSERT_EQ(named[i].key(), std::to_string(i));
    // Identity, not value equality.
    ASSERT_EQ(named[i].value().get(), expected[i].get());
  }
}
801 
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  // parameters() must return exactly two tensors that share storage with the
  // members p1 and p2, in that order.
  TestModule module(1);
  std::vector<torch::Tensor> params = module.parameters();
  ASSERT_EQ(params.size(), 2);

  const std::vector<float*> expected_ptrs = {
      module.p1.data_ptr<float>(), module.p2.data_ptr<float>()};
  for (const auto idx : c10::irange(expected_ptrs.size())) {
    ASSERT_EQ(params[idx].data_ptr<float>(), expected_ptrs[idx]);
  }
}
809 
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  // named_parameters() must report "p1" then "p2", each aliasing the storage
  // of the corresponding member tensor.
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> named =
      module.named_parameters();
  ASSERT_EQ(named.size(), 2);

  const std::vector<std::pair<std::string, float*>> expected = {
      {"p1", module.p1.data_ptr<float>()},
      {"p2", module.p2.data_ptr<float>()}};
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(named[idx].key(), expected[idx].first);
    ASSERT_EQ(named[idx]->data_ptr<float>(), expected[idx].second);
  }
}
820 
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  // buffers() must return exactly two tensors that share storage with the
  // members b1 and b2, in that order.
  TestModule module(1);
  std::vector<torch::Tensor> bufs = module.buffers();
  ASSERT_EQ(bufs.size(), 2);

  const std::vector<float*> expected_ptrs = {
      module.b1.data_ptr<float>(), module.b2.data_ptr<float>()};
  for (const auto idx : c10::irange(expected_ptrs.size())) {
    ASSERT_EQ(bufs[idx].data_ptr<float>(), expected_ptrs[idx]);
  }
}
828 
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  // named_buffers() must report "b1" then "b2", each aliasing the storage of
  // the corresponding member tensor.
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> named = module.named_buffers();
  ASSERT_EQ(named.size(), 2);

  const std::vector<std::pair<std::string, float*>> expected = {
      {"b1", module.b1.data_ptr<float>()},
      {"b2", module.b2.data_ptr<float>()}};
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(named[idx].key(), expected[idx].first);
    ASSERT_EQ(named[idx]->data_ptr<float>(), expected[idx].second);
  }
}
839 
840 struct TestContainer : torch::nn::Module {
TestContainerTestContainer841   TestContainer(int64_t number, std::vector<TestContainer> modules = {})
842       : tensor(torch::tensor(number)) {
843     for (const auto i : c10::irange(modules.size())) {
844       register_module(
845           std::to_string(i),
846           std::make_shared<TestContainer>(std::move(modules[i])));
847     }
848   }
849   torch::Tensor tensor;
850 };
851 
get_test_container_item(std::shared_ptr<torch::nn::Module> module)852 int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
853   return std::dynamic_pointer_cast<TestContainer>(module)
854       ->tensor.item<int64_t>();
855 }
856 
make_deeply_nested_test_container()857 std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
858   return std::make_shared<TestContainer>(TestContainer(
859       0,
860       {TestContainer(1, {TestContainer(2), TestContainer(3)}),
861        TestContainer(4),
862        TestContainer(
863            5,
864            {TestContainer(6),
865             TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
866 }
867 
/// Expected (dotted name, stored value) pairs for the tree built by
/// make_deeply_nested_test_container() when traversed with the name prefix
/// "test_prefix". The i-th pair carries the value i.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  const std::string prefix = "test_prefix";
  // Path suffixes listed in the same order as the values 0..9.
  const char* const suffixes[] = {
      "", ".0", ".0.0", ".0.1", ".1", ".2", ".2.0", ".2.1", ".2.1.0", ".2.1.1"};

  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.reserve(sizeof(suffixes) / sizeof(suffixes[0]));
  int64_t value = 0;
  for (const char* suffix : suffixes) {
    pairs.emplace_back(prefix + suffix, value++);
  }
  return pairs;
}
882 
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  // modules() includes the root and every descendant; the nested container
  // numbers its nodes so that the i-th module visited must hold the value i.
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();

  ASSERT_EQ(modules.size(), 10u);
  for (const auto i : c10::irange(modules.size())) {
    // The index is size_t while stored items are int64_t; convert explicitly
    // instead of comparing signed with unsigned.
    ASSERT_EQ(get_test_container_item(modules[i]), static_cast<int64_t>(i));
  }
}
892 
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  // Every dotted path reported by named_modules() starts with the supplied
  // prefix; names and stored values must match the reference table in order.
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> named =
      model->named_modules(/*name_prefix=*/"test_prefix");

  ASSERT_EQ(named.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    const auto& [path, item] = expected[i];
    ASSERT_EQ(named[i].key(), path);
    ASSERT_EQ(get_test_container_item(named[i].value()), item);
  }
}
906 
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  // children() is not recursive: only the root's three direct children
  // (items 1, 4 and 5) appear, in registration order.
  const std::vector<int64_t> expected_items = {1, 4, 5};
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> kids = model->children();

  ASSERT_EQ(kids.size(), expected_items.size());
  for (const auto idx : c10::irange(kids.size())) {
    ASSERT_EQ(get_test_container_item(kids[idx]), expected_items[idx]);
  }
}
916 
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  // Only direct children are reported, each named by its registration index.
  const std::vector<std::pair<std::string, int64_t>> expected = {
      {"0", 1}, {"1", 4}, {"2", 5}};
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> kids =
      model->named_children();

  ASSERT_EQ(kids.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(get_test_container_item(kids[idx].value()), expected[idx].second);
    ASSERT_EQ(kids[idx].key(), expected[idx].first);
  }
}
933 
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  // apply() must visit the root and all ten nodes in the same order used to
  // number them, so a running counter tracks each node's stored value.
  auto model = make_deeply_nested_test_container();
  int64_t visited = 0;
  auto check_node = [&visited](torch::nn::Module& node) {
    const int64_t value = node.as<TestContainer>()->tensor.item<int64_t>();
    ASSERT_EQ(value, visited++);
  };
  model->apply(check_node);
  ASSERT_EQ(visited, 10);
}
942 
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  // Holding the root through a pointer-to-const selects the const apply()
  // overload, whose callback receives `const Module&`.
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t visited = 0;
  auto check_node = [&visited](const torch::nn::Module& node) {
    const int64_t value = node.as<TestContainer>()->tensor.item<int64_t>();
    ASSERT_EQ(value, visited++);
  };
  model->apply(check_node);
  ASSERT_EQ(visited, 10);
}
952 
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  // The named apply() overload must report each node's prefixed dotted path
  // alongside the node itself, in the order of the reference table.
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      // Capture `expected` by reference (as the sibling tests do) rather than
      // copying the whole table into the closure.
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        // Guard before indexing: an unexpected extra visit must fail cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, static_cast<int64_t>(expected.size()));
}
967 
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  // Const variant of the named apply(): the root is held through a
  // pointer-to-const, so the callback receives `const Module&`.
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name, const torch::nn::Module& module) {
        // Guard before indexing: an unexpected extra visit must fail cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, static_cast<int64_t>(expected.size()));
}
984 
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  // The shared_ptr flavour of apply() must visit nodes in the same order,
  // handing each one out as a shared_ptr<Module>.
  auto model = make_deeply_nested_test_container();
  int64_t visits = 0;
  auto check_node =
      [&visits](const std::shared_ptr<torch::nn::Module>& node) {
        ASSERT_EQ(get_test_container_item(node), visits++);
      };
  model->apply(check_node);
  ASSERT_EQ(visits, 10);
}
993 
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  // Named shared_ptr flavour of apply(): each callback gets the prefixed
  // dotted path and the node as a shared_ptr<Module>.
  auto model = make_deeply_nested_test_container();
  const auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        // Guard before indexing: an unexpected extra visit must fail cleanly
        // instead of reading past the end of `expected`.
        ASSERT_LT(static_cast<size_t>(index), expected.size());
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(get_test_container_item(module), expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, static_cast<int64_t>(expected.size()));
}
1008 
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  // By default modules() includes the root, handing it back as a shared_ptr
  // (per the expected error text below). That only works when the root is
  // itself owned by a shared_ptr.
  {
    // Stack-allocated root with the root included: must throw.
    TestModule module(1);
    ASSERT_THROWS_WITH(
        module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr");
  }
  {
    // Excluding the root sidesteps the shared_ptr requirement.
    TestModule module(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module.modules(/*include_self=*/false));
  }
  {
    // A shared_ptr-owned root can be returned along with its descendants.
    auto module = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module->modules());
  }
}
1028 
1029 struct EmptyModule : torch::nn::Module {};
1030 
TEST_F(ModuleTest, PrettyPrint) {
  // A module that overrides pretty_print() controls its own textual form.
  struct TestModule : torch::nn::Module {
    TestModule(int x, float y) : x_(x), y_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  // Without an override, the printed form is just the class name.
  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  // Pass a float literal so the `float y` parameter receives no implicit
  // double -> float narrowing; the printed value is unchanged.
  ASSERT_EQ(c10::str(TestModule(1, 3.14f)), "TestModule(x=1, y=3.14)");
}
1046 
1047 struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
forwardModuleWithNonTensorForwardImpl1048   int64_t forward(torch::Tensor x) {
1049     return x.numel();
1050   }
1051 };
1052 TORCH_MODULE(ModuleWithNonTensorForward);
1053 
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  // Invoking the holder must reach Impl::forward and return its int64_t
  // result unchanged: the element count of a 123-element tensor.
  ModuleWithNonTensorForward holder;
  const auto input = torch::ones(123);
  ASSERT_EQ(holder(input), 123);
}
1058