#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <torch/expanding_array.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/options/activation.h>
#include <limits>
#include <random>

using namespace torch::nn;
using namespace torch::test;
16 
17 class TestModel : public torch::nn::Module {
18  public:
TestModel()19   TestModel()
20       : l1(register_module("l1", Linear(10, 3))),
21         l2(register_module("l2", Linear(3, 5))),
22         l3(register_module("l3", Linear(5, 100))) {}
23 
24   Linear l1, l2, l3;
25 };
26 
27 class NestedModel : public torch::nn::Module {
28  public:
NestedModel()29   NestedModel()
30       : param_(register_parameter("param", torch::empty({3, 2, 21}))),
31         l1(register_module("l1", Linear(5, 20))),
32         t(register_module("test", std::make_shared<TestModel>())) {}
33 
34   torch::Tensor param_;
35   Linear l1;
36   std::shared_ptr<TestModel> t;
37 };
38 
39 struct ModulesTest : torch::test::SeedingFixture {};
40 
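// The expected values below can be checked by hand. With the weight set to
// arange(18) and the input to arange(30), the first output element is
// the sum over c, k of w[0][c][k] * x[0][c][k]
//   = (0*0 + 1*1 + 2*2) + (3*5 + 4*6 + 5*7) + (6*10 + 7*11 + 8*12)
//   = 5 + 74 + 233 = 312.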
TEST_F(ModulesTest, Conv1d) {
  Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
  model->weight.set_data(
      torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3}));
  auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({2, 3, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{312., 348., 384.}, {798., 915., 1032.}},

       {{852., 888., 924.}, {2553., 2670., 2787.}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
}

TEST_F(ModulesTest, Conv1dSameStrided) {
  auto options = Conv1dOptions(3, 2, 3);
  options.stride(1).padding(torch::kSame);
  Conv1d model_valid(options);
  ASSERT_THROWS_WITH(
      [&] { Conv1d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
}

TEST_F(ModulesTest, Conv1dInvalidArg) {
  auto options = Conv1dOptions(3, 2, 3).groups(-1);
  ASSERT_THROWS_WITH(
      Conv1d(options), "in_channels, groups and out_channels must");
}

TEST_F(ModulesTest, Conv2dEven) {
  Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
  model->weight.set_data(
      torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3}));
  auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},

        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
}

TEST_F(ModulesTest, Conv2dUneven) {
  Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
  model->weight.set_data(
      torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2}));
  auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 4});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},

        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
}

TEST_F(ModulesTest, Conv2dSameStrided) {
  auto options = Conv2dOptions(3, 2, {3, 4});
  options.stride(1).padding(torch::kSame);
  Conv2d model_valid(options);
  ASSERT_THROWS_WITH(
      [&] { Conv2d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] {
        Conv2d model_invalid(options.stride({1, 2}));
      }(),
      "padding='same' is not supported for strided convolutions");
}

TEST_F(ModulesTest, Conv3d) {
  Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
  model->weight.set_data(
      torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3}));
  auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{{700704., 703944., 707184.},
          {716904., 720144., 723384.},
          {733104., 736344., 739584.}},

         {{781704., 784944., 788184.},
          {797904., 801144., 804384.},
          {814104., 817344., 820584.}},

         {{862704., 865944., 869184.},
          {878904., 882144., 885384.},
          {895104., 898344., 901584.}}},

        {{{1724220., 1734021., 1743822.},
          {1773225., 1783026., 1792827.},
          {1822230., 1832031., 1841832.}},

         {{1969245., 1979046., 1988847.},
          {2018250., 2028051., 2037852.},
          {2067255., 2077056., 2086857.}},

         {{2214270., 2224071., 2233872.},
          {2263275., 2273076., 2282877.},
          {2312280., 2322081., 2331882.}}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3 * 3);
}

TEST_F(ModulesTest, Conv3dSameStrided) {
  auto options = Conv3dOptions(3, 2, {3, 4, 5});
  options.stride(1).padding(torch::kSame);
  Conv3d model_valid(options);
  ASSERT_THROWS_WITH(
      [&] { Conv3d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] {
        Conv3d model_invalid(options.stride({1, 2, 1}));
      }(),
      "padding='same' is not supported for strided convolutions");
}

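// For a transposed convolution with no padding, the output length is
// (L_in - 1) * stride + kernel_size; here (5 - 1) * 1 + 3 = 7, matching the
// seven columns of the expected tensor below.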
TEST_F(ModulesTest, ConvTranspose1d) {
  ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false));
  model->weight.set_data(torch::arange(18.).view({2, 3, 3}));
  auto x = torch::arange(20.).reshape({2, 2, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
}

TEST_F(ModulesTest, ConvTranspose2dEven) {
  ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false));
  model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3}));
  auto x = torch::arange(50.).view({1, 2, 5, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
}

TEST_F(ModulesTest, ConvTranspose2dUneven) {
  ConvTranspose2d model(
      ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
  model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2}));
  auto x = torch::arange(40.).view({1, 2, 5, 4});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
}

TEST_F(ModulesTest, ConvTranspose3d) {
  ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false));
  model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
  auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(model->weight.grad().numel(), 2 * 2 * 2 * 2 * 2);
}

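// With no padding, the pooled length is floor((L_in - kernel_size) / stride)
// + 1; for L_in = 5, kernel_size = 3 and stride = 2 that is 2, so pooling an
// all-ones input yields an all-ones tensor of shape {1, 1, 2}.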
TEST_F(ModulesTest, MaxPool1d) {
  MaxPool1d model(MaxPool1dOptions(3).stride(2));
  auto x = torch::ones({1, 1, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}

TEST_F(ModulesTest, MaxPool1dReturnIndices) {
  MaxPool1d model(MaxPool1dOptions(3).stride(2));
  auto x = torch::ones({1, 1, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));

  ASSERT_TRUE(
      torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 2}));
}

TEST_F(ModulesTest, MaxPool2dEven) {
  MaxPool2d model(MaxPool2dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, MaxPool2dUneven) {
  MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2}));
  auto x = torch::ones({2, 5, 4}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}

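// forward_with_indices reports, for each pooled maximum, its offset in the
// flattened spatial plane of the input: in a 5x5 plane, (row, col) maps to
// row * 5 + col, so maxima at (0, 0), (0, 2), (2, 0) and (2, 2) appear as
// 0, 2, 10 and 12.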
TEST_F(ModulesTest, MaxPool2dReturnIndices) {
  MaxPool2d model(MaxPool2dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, MaxPool3d) {
  MaxPool3d model(MaxPool3dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}

TEST_F(ModulesTest, MaxPool3dReturnIndices) {
  MaxPool3d model(MaxPool3dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}},
          torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}

TEST_F(ModulesTest, AvgPool1d) {
  AvgPool1d model(AvgPool1dOptions(3).stride(2));
  auto x = torch::ones({1, 1, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}

TEST_F(ModulesTest, AvgPool2dEven) {
  AvgPool2d model(AvgPool2dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, AvgPool2dUneven) {
  AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2}));
  auto x = torch::ones({2, 5, 4}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, AvgPool3d) {
  AvgPool3d model(AvgPool3dOptions(3).stride(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}

TEST_F(ModulesTest, FractionalMaxPool2d) {
  FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) {
  FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
  ASSERT_TRUE(torch::allclose(
      indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
}

TEST_F(ModulesTest, FractionalMaxPool3d) {
  FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}

TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) {
  FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}

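// LPPool raises every element in the window to norm_type, sums, and takes
// the norm_type-th root: (sum |x|^p)^(1/p). For an all-ones input with p = 2
// and kernel size 3 this yields (1 * 3)^(1/2) per output element, which is
// exactly how the expected tensors below are constructed.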
TEST_F(ModulesTest, LPPool1d) {
  int norm_type = 2;
  int stride = 2;
  int kernel_size = 3;

  LPPool1d model(LPPool1dOptions(norm_type, kernel_size).stride(stride));
  auto x = torch::ones({1, 1, 5});
  auto y = model(x);
  auto expected =
      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
       kernel_size)
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
}

TEST_F(ModulesTest, LPPool2d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
  auto x = torch::ones({1, 1, 2, 5});
  auto y = model(x);
  auto expected =
      (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
}

TEST_F(ModulesTest, LPPool3d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({1, 2, 3});

  LPPool3d model(LPPool3dOptions(norm_type, kernel_size).stride(stride));
  auto x = torch::ones({1, 1, 1, 2, 5});
  auto y = model(x);
  auto expected =
      (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1] * kernel_size[2]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 5);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
}

TEST_F(ModulesTest, Identity) {
  Identity identity;
  auto input = torch::tensor(
      {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = identity->forward(input);
  auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(torch::equal(output, expected));
  ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
}

TEST_F(ModulesTest, Flatten) {
  Flatten flatten;
  auto input = torch::tensor(
      {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = flatten->forward(input);
  auto expected = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat);
  auto s = output.sum();

  s.backward();
  ASSERT_TRUE(torch::equal(output, expected));
  ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));

  // Testing with optional arguments start_dim and end_dim
  Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3));
  input = torch::tensor(
      {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
       {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}},
      torch::dtype(torch::kFloat)
          .requires_grad(true)); // Tensor with sizes (2, 2, 2, 2)

  output = flatten_optional_dims->forward(input);
  expected = torch::tensor(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}},
      torch::kFloat); // Tensor with sizes (2, 2, 4)

  s = output.sum();
  s.backward();
  ASSERT_TRUE(torch::equal(output, expected));
  ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
}

TEST_F(ModulesTest, Unflatten) {
  // Non-named tensor
  Unflatten unflatten(UnflattenOptions(0, {2, 2}));
  auto output = unflatten->forward(torch::tensor({1, 2, 3, 4}));
  auto expected = torch::tensor({{1, 2}, {3, 4}});
  ASSERT_TRUE(torch::equal(output, expected));

  // Named tensor
  auto make_dimnames = [](std::vector<std::string> names) {
    std::vector<torch::Dimname> dimnames;
    // NOLINTNEXTLINE(performance-for-range-copy)
    for (auto name : names) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      dimnames.push_back(
          torch::Dimname::fromSymbol(torch::Symbol::dimname(name)));
    }
    return dimnames;
  };

  unflatten = Unflatten(UnflattenOptions(
      "B",
      {std::pair<std::string, int64_t>{"B1", 2},
       std::pair<std::string, int64_t>{"B2", 2}}));
  output = unflatten->forward(
      torch::tensor({{1, 2, 3, 4}}).refine_names(make_dimnames({"A", "B"})));
  expected = torch::tensor({{{1, 2}, {3, 4}}})
                 .refine_names(make_dimnames({"A", "B1", "B2"}));
  ASSERT_TRUE(torch::equal(output, expected));
}

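// Adaptive pooling chooses its own windows: output cell i covers the input
// range [floor(i * L_in / L_out), ceil((i + 1) * L_in / L_out)). For
// L_in = 5 and L_out = 3 the windows are [0, 2), [1, 4) and [3, 5), whose
// maxima over {1, 2, 3, 4, 5} are 2, 4 and 5.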
TEST_F(ModulesTest, AdaptiveMaxPool1d) {
  AdaptiveMaxPool1d model(3);
  auto x = torch::tensor(
      {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
}

TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) {
  AdaptiveMaxPool1d model(3);
  auto x = torch::tensor(
      {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
  ASSERT_TRUE(
      torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 3}));
}

TEST_F(ModulesTest, AdaptiveMaxPool2dEven) {
  AdaptiveMaxPool2d model(3);
  auto x = torch::arange(0., 50);
  x.resize_({2, 5, 5}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
              {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
          },
          torch::kFloat)));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
}

TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) {
  AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
  auto x = torch::arange(0., 40);
  x.resize_({2, 5, 4}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{5, 7}, {13, 15}, {17, 19}},
              {{25, 27}, {33, 35}, {37, 39}},
          },
          torch::kFloat)));
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
}

TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) {
  AdaptiveMaxPool2d model(3);
  auto x = torch::arange(0., 50);
  x.resize_({2, 5, 5}).set_requires_grad(true);
  auto [y, indices] = model->forward_with_indices(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
              {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));

  ASSERT_EQ(indices.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {
              {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
              {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
          },
          torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 3}));
}

TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) {
  AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
  auto x = torch::arange(0., 40);
  x.resize_({2, 5, 4}).set_requires_grad(true);
  auto [y, indices] = model->forward_with_indices(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{5, 7}, {13, 15}, {17, 19}},
              {{25, 27}, {33, 35}, {37, 39}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));

  ASSERT_EQ(indices.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {
              {{5, 7}, {13, 15}, {17, 19}},
              {{5, 7}, {13, 15}, {17, 19}},
          },
          torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 2}));
}

TEST_F(ModulesTest, AdaptiveMaxPool3d) {
  AdaptiveMaxPool3d model(3);
  auto x = torch::arange(0., 64);
  x.resize_({1, 4, 4, 4}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
              {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
              {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}

TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) {
  AdaptiveMaxPool3d model(3);
  auto x = torch::arange(0., 64);
  x.resize_({1, 4, 4, 4}).set_requires_grad(true);
  auto [y, indices] = model->forward_with_indices(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
              {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
              {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));

  ASSERT_EQ(indices.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {
              {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
              {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
              {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
          },
          torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}

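// AdaptiveAvgPool uses the same window rule as adaptive max pooling: over
// the input {1, 2, 3, 4, 5} the three windows are {1, 2}, {2, 3, 4} and
// {4, 5}, whose means are the asserted 1.5, 3.0, 4.5.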
TEST_F(ModulesTest, AdaptiveAvgPool1d) {
  AdaptiveAvgPool1d model(3);
  auto x = torch::tensor(
      {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(
      torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
}

TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
  AdaptiveAvgPool2d model(3);
  auto x = torch::arange(0., 50);
  x.resize_({2, 5, 5}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}},
              {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
}

TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
  AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
  auto x = torch::arange(0., 40);
  x.resize_({2, 5, 4}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}},
              {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
}

TEST_F(ModulesTest, AdaptiveAvgPool3d) {
  AdaptiveAvgPool3d model(3);
  auto x = torch::arange(0., 64);
  x.resize_({1, 4, 4, 4}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {
              {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}},
              {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}},
              {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}},
          },
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}

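// MaxUnpool scatters each value back to the position recorded in `indices`
// and fills everything else with zeros. Without an explicit output_size the
// output length defaults to (L_in - 1) * stride + kernel_size, i.e.
// (3 - 1) * 3 + 3 = 9 in the first case below.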
TEST_F(ModulesTest, MaxUnpool1d) {
  auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  auto x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto model = MaxUnpool1d{3};
  auto y = model->forward(x, indices);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));

  indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  model = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)};
  y = model->forward(x, indices, std::vector<int64_t>({1, 1, 5}));

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(
      torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
}

TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) {
  MaxPool1d pool{MaxPool1dOptions(2).stride(2)};
  MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)};
  auto input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat);
  auto [output, indices] = pool->forward_with_indices(input);
  ASSERT_TRUE(torch::allclose(
      unpool(output, indices),
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));

  // Example showcasing the use of output_size
  input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat);
  std::tie(output, indices) = pool->forward_with_indices(input);
  ASSERT_TRUE(torch::allclose(
      unpool(output, indices, input.sizes().vec()),
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat)));
  ASSERT_TRUE(torch::allclose(
      unpool(output, indices),
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));
}

TEST_F(ModulesTest, MaxUnpool2d) {
  auto indices = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
      torch::kLong);
  auto x = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto model = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)};
  auto y = model->forward(x, indices);

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{0, 0, 0, 0, 0},
             {0, 6, 0, 8, 9},
             {0, 0, 0, 0, 0},
             {0, 16, 0, 18, 19},
             {0, 21, 0, 23, 24}}},
           {{{0, 0, 0, 0, 0},
             {0, 31, 0, 33, 34},
             {0, 0, 0, 0, 0},
             {0, 41, 0, 43, 44},
             {0, 46, 0, 48, 49}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
}

TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) {
  MaxPool2d pool{MaxPool2dOptions(2).stride(2)};
  MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)};
  auto input = torch::tensor(
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}},
      torch::kFloat);
  auto [output, indices] = pool->forward_with_indices(input);
  ASSERT_TRUE(torch::allclose(
      unpool(output, indices),
      torch::tensor(
          {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}},
          torch::kFloat)));

  ASSERT_TRUE(torch::allclose(
      unpool(output, indices, std::vector<int64_t>{1, 1, 5, 5}),
      torch::tensor(
          {{{{0, 0, 0, 0, 0},
             {6, 0, 8, 0, 0},
             {0, 0, 0, 14, 0},
             {16, 0, 0, 0, 0},
             {0, 0, 0, 0, 0}}}},
          torch::kFloat)));
}

TEST_F(ModulesTest, MaxUnpool3d) {
  auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
  auto x = torch::tensor(
      {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto model = MaxUnpool3d{3};
  auto y = model->forward(x, indices);

  ASSERT_EQ(y.dim(), 5);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
             {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
             {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
}

TEST_F(ModulesTest, MaxUnpool3dOutputSize) {
  auto indices = torch::tensor(
      {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong);
  auto x = torch::tensor(
      {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)};
  auto y = model->forward(x, indices, std::vector<int64_t>({1, 1, 4, 4, 4}));

  ASSERT_EQ(y.dim(), 5);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
             {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}},
             {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
             {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 4, 4, 4}));
}

TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) {
  MaxPool3d pool{MaxPool3dOptions(3).stride(2)};
  MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)};
  auto input = torch::randn({20, 16, 51, 33, 15});
  auto [output, indices] = pool->forward_with_indices(input);
  auto unpooled_output = unpool(output, indices);
  ASSERT_EQ(
      unpooled_output.sizes(), std::vector<int64_t>({20, 16, 51, 33, 15}));
}

TEST_F(ModulesTest, Linear) {
  {
    Linear model(5, 2);
    auto x = torch::randn({10, 5}, torch::requires_grad());
    auto y = model(x);
    torch::Tensor s = y.sum();

    s.backward();
    ASSERT_EQ(y.ndimension(), 2);
    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.size(0), 10);
    ASSERT_EQ(y.size(1), 2);

    ASSERT_EQ(model->weight.grad().numel(), 2 * 5);

    auto y_exp = torch::addmm(model->bias, x, model->weight.t());
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
  {
    Linear model(LinearOptions(5, 2).bias(false));
    auto x = torch::randn({10, 5}, torch::requires_grad());
    auto y = model(x);
    torch::Tensor s = y.sum();

    s.backward();
    ASSERT_EQ(y.ndimension(), 2);
    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.size(0), 10);
    ASSERT_EQ(y.size(1), 2);

    ASSERT_EQ(model->weight.grad().numel(), 2 * 5);

    auto y_exp = torch::mm(x, model->weight.t());
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
}

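// LocalResponseNorm normalizes each channel by its neighbors:
// b_c = a_c / (k + alpha / n * sum over the window of a_{c'}^2)^beta.
// The expected values below correspond to a window of size n = 2 together
// with the documented defaults alpha = 1e-4, beta = 0.75, k = 1.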
TEST_F(ModulesTest, LocalResponseNorm) {
  {
    LocalResponseNorm model(LocalResponseNormOptions(2));
    const auto x =
        torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2});
    auto y = model(x);
    const auto y_exp = torch::tensor(
        {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}},

          {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}},

          {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}},

         {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}},

          {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}},

          {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}},
        torch::kFloat);
    torch::Tensor s = y.sum();

    s.backward();
    ASSERT_EQ(y.ndimension(), 4);
    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.sizes(), x.sizes());
    ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
  }
}

TEST_F(ModulesTest, LayerNorm) {
  LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5));
  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto y = model(x);
  auto y_exp = torch::layer_norm(x, {2, 2}, model->weight, model->bias, 2e-5);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 2 * 2);
  ASSERT_TRUE(torch::allclose(y, y_exp));
}

TEST_F(ModulesTest, GroupNorm) {
  GroupNorm model(GroupNormOptions(2, 2).eps(2e-5));
  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto y = model(x);
  auto y_exp = torch::group_norm(x, 2, model->weight, model->bias, 2e-5);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 2);
  ASSERT_TRUE(torch::allclose(y, y_exp));
}

TEST_F(ModulesTest, Bilinear) {
  Bilinear model(5, 3, 2);
  auto x1 = torch::randn({10, 5}, torch::requires_grad());
  auto x2 = torch::randn({10, 3}, torch::requires_grad());
  auto y = model(x1, x2);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5 * 3);
}

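// Fold is the adjoint of Unfold: it sums sliding blocks back into a tensor,
// so output positions covered by two overlapping 2x2 blocks accumulate to
// 2.0 in the all-ones example below.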
TEST_F(ModulesTest, Fold) {
  {
    Fold model(FoldOptions({3, 2}, {2, 2}));
    auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::requires_grad());
    auto output = model(input);
    auto expected = torch::tensor(
        {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
          {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
          {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
        torch::kFloat);
    auto s = output.sum();
    s.backward();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // input wrong dimension
    Fold model(FoldOptions({8, 8}, {3, 3}));
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 3, 16, 16})),
        "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported (got 4D)");
  }
}

TEST_F(ModulesTest, Unfold) {
  {
    Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2));
    auto input =
        torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3});
    auto output = model(input);
    auto expected = torch::tensor(
        {{{0.0, 0.0, 0.0, 6.0},
          {0.0, 0.0, 5.0, 7.0},
          {0.0, 3.0, 0.0, 0.0},
          {2.0, 4.0, 0.0, 0.0},
          {0.0, 0.0, 0.0, 12.0},
          {0.0, 0.0, 11.0, 13.0},
          {0.0, 9.0, 0.0, 0.0},
          {8.0, 10.0, 0.0, 0.0}}},
        torch::kFloat);
    auto s = output.sum();
    s.backward();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // input wrong dimension
    Unfold model(UnfoldOptions({2, 4}));
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 5, 2})),
        "Input Error: Only 4D input Tensors are supported (got 3D)");
  }
  {
    // calculated output shape is too small
    Unfold model(UnfoldOptions({2, 3}));
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 2, 2, 2})),
        "Given input with spatial size (2, 2), kernel_size=(2, 3), "
        "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
        "sliding blocks as (1, 0), but its components must be at least one.");
  }
}

TEST_F(ModulesTest, SimpleContainer) {
  auto model = std::make_shared<SimpleContainer>();
  auto l1 = model->add(Linear(10, 3), "l1");
  auto l2 = model->add(Linear(3, 5), "l2");
  auto l3 = model->add(Linear(5, 100), "l3");

  auto x = torch::randn({1000, 10}, torch::requires_grad());
  x = l1(x).clamp_min(0);
  x = l2(x).clamp_min(0);
  x = l3(x).clamp_min(0);

  x.backward(torch::ones_like(x));
  ASSERT_EQ(x.ndimension(), 2);
  ASSERT_EQ(x.size(0), 1000);
  ASSERT_EQ(x.size(1), 100);
  ASSERT_EQ(x.min().item<float>(), 0);
}

TEST_F(ModulesTest, EmbeddingBasic) {
  const int64_t dict_size = 10;
  Embedding model(dict_size, 2);
  ASSERT_TRUE(model->named_parameters().contains("weight"));
  ASSERT_EQ(model->weight.ndimension(), 2);
  ASSERT_EQ(model->weight.size(0), dict_size);
  ASSERT_EQ(model->weight.size(1), 2);

  // Indices (the input) cannot receive gradients; only the embedding
  // weights can.
  auto x = torch::full({10}, dict_size - 1, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size);
}

TEST_F(ModulesTest, EmbeddingList) {
  Embedding model(6, 4);
  auto x = torch::full({2, 3}, 5, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.size(0), 2);
  ASSERT_EQ(y.size(1), 3);
  ASSERT_EQ(y.size(2), 4);
}

TEST_F(ModulesTest, EmbeddingFromPretrained) {
  auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
  Embedding embedding = torch::nn::Embedding::from_pretrained(weight);
  auto input = torch::tensor({1}, torch::kLong);
  ASSERT_TRUE(torch::allclose(
      embedding(input), torch::tensor({4.0000, 5.1000, 6.3000})));
}

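// EmbeddingBag reduces each bag of lookups, using mean reduction by default:
// rows 1 and 0 of the weight average to
// ((4 + 1) / 2, (5.1 + 2.3) / 2, (6.3 + 3) / 2) = (2.5, 3.7, 4.65).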
TEST_F(ModulesTest, EmbeddingBagFromPretrained) {
  auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
  EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight);
  auto input = torch::zeros({{1, 2}}, torch::kLong);
  input[0] = torch::tensor({1, 0});
  ASSERT_TRUE(torch::allclose(
      embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500})));
}

TEST_F(ModulesTest, AlphaDropout) {
  AlphaDropout alpha_dropout(0.5);
  torch::Tensor x = torch::ones(100, torch::requires_grad());
  torch::Tensor y = alpha_dropout(x);

  y.backward(torch::ones_like(y));

  ASSERT_EQ(y.ndimension(), 1);
  ASSERT_EQ(y.size(0), 100);
  ASSERT_LT(y.sum().item<float>(), 130); // Probably
  ASSERT_GT(y.sum().item<float>(), 40); // Probably

  alpha_dropout->eval();
  y = alpha_dropout(x);

  ASSERT_EQ(y.sum().item<float>(), 100);
}

TEST_F(ModulesTest, FeatureAlphaDropout) {
  FeatureAlphaDropout feature_alpha_dropout(0.5);
  torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
  torch::Tensor y = feature_alpha_dropout(x);

  y.backward(torch::ones_like(y));

  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 10);
  ASSERT_LT(y.sum().item<float>(), 130); // Probably
  ASSERT_GT(y.sum().item<float>(), 40); // Probably

  feature_alpha_dropout->eval();
  y = feature_alpha_dropout(x);

  ASSERT_EQ(y.sum().item<float>(), 100);
}

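// With p = 0.5 each surviving element is scaled by 1 / (1 - p) = 2, so the
// sum of 100 ones keeps mean 100 with a standard deviation of 10. The
// 70..130 bounds below are therefore a ~3-sigma sanity check, not a hard
// guarantee (hence the "Probably" comments).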
TEST_F(ModulesTest,Dropout)1352 TEST_F(ModulesTest, Dropout) {
1353   for (const auto inplace : {false, true}) {
1354     Dropout dropout(DropoutOptions(0.5).inplace(inplace));
1355     torch::Tensor x = torch::ones(100);
1356     if (!inplace) {
1357       x.requires_grad_(true);
1358     }
1359     torch::Tensor y = dropout(x);
1360 
1361     ASSERT_EQ(y.ndimension(), 1);
1362     ASSERT_EQ(y.size(0), 100);
1363     ASSERT_LT(y.sum().item<float>(), 130); // Probably
1364     ASSERT_GT(y.sum().item<float>(), 70); // Probably
1365     if (inplace) {
1366       ASSERT_TRUE(y.allclose(x));
1367     } else {
1368       y.backward(torch::ones_like(y));
1369     }
1370 
1371     dropout->eval();
1372     y = dropout(torch::ones(100));
1373     ASSERT_EQ(y.sum().item<float>(), 100);
1374   }
1375 }
1376 
TEST_F(ModulesTest,Dropout2d)1377 TEST_F(ModulesTest, Dropout2d) {
1378   auto p = 0.5;
1379   for (const auto inplace : {false, true}) {
1380     Dropout2d dropout(Dropout2dOptions(p).inplace(inplace));
1381     torch::Tensor x = torch::empty({50, 50, 2, 2}).fill_(1 - p);
1382     if (!inplace) {
1383       x.requires_grad_(true);
1384     }
1385     torch::Tensor y = dropout(x);
1386 
1387     ASSERT_EQ(y.ndimension(), 4);
1388     ASSERT_EQ(y.size(0), 50);
1389     ASSERT_EQ(y.size(1), 50);
1390     ASSERT_EQ(y.size(2), 2);
1391     ASSERT_EQ(y.size(3), 2);
1392     ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1393 
1394     if (inplace) {
1395       ASSERT_TRUE(y.allclose(x));
1396     } else {
1397       y.backward(torch::ones_like(y));
1398     }
1399 
1400     dropout->eval();
1401     y = dropout(torch::ones({2, 2, 10, 10}));
1402     ASSERT_EQ(y.sum().item<float>(), 400);
1403   }
1404 }
1405 
TEST_F(ModulesTest,Dropout3d)1406 TEST_F(ModulesTest, Dropout3d) {
1407   for (const auto inplace : {false, true}) {
1408     auto p = 0.5;
1409     Dropout3d dropout(Dropout3dOptions(p).inplace(inplace));
1410     torch::Tensor x = torch::empty({50, 50, 2, 2, 2}).fill_(1 - p);
1411     if (!inplace) {
1412       x.requires_grad_(true);
1413     }
1414     torch::Tensor y = dropout(x);
1415 
1416     ASSERT_EQ(y.ndimension(), 5);
1417     ASSERT_EQ(y.size(0), 50);
1418     ASSERT_EQ(y.size(1), 50);
1419     ASSERT_EQ(y.size(2), 2);
1420     ASSERT_EQ(y.size(3), 2);
1421     ASSERT_EQ(y.size(4), 2);
1422     ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1423 
1424     if (inplace) {
1425       ASSERT_TRUE(y.allclose(x));
1426     } else {
1427       y.backward(torch::ones_like(y));
1428     }
1429 
1430     dropout->eval();
1431     y = dropout(torch::ones({4, 4, 5, 5}));
1432     ASSERT_EQ(y.sum().item<float>(), 400);
1433   }
1434 }
1435 
TEST_F(ModulesTest,Parameters)1436 TEST_F(ModulesTest, Parameters) {
1437   auto model = std::make_shared<NestedModel>();
1438   auto parameters = model->named_parameters();
1439   ASSERT_EQ(parameters["param"].size(0), 3);
1440   ASSERT_EQ(parameters["param"].size(1), 2);
1441   ASSERT_EQ(parameters["param"].size(2), 21);
1442   ASSERT_EQ(parameters["l1.bias"].size(0), 20);
1443   ASSERT_EQ(parameters["l1.weight"].size(0), 20);
1444   ASSERT_EQ(parameters["l1.weight"].size(1), 5);
1445   ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
1446   ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
1447   ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
1448   ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
1449   ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
1450   ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
1451   ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
1452   ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
1453   ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
1454 }
1455 
TEST_F(ModulesTest,FunctionalCallsSuppliedFunction)1456 TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
1457   bool was_called = false;
1458   auto functional = Functional([&was_called](torch::Tensor input) {
1459     was_called = true;
1460     return input;
1461   });
1462   auto output = functional(torch::ones(5, torch::requires_grad()));
1463   ASSERT_TRUE(was_called);
1464   ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1465 
1466   was_called = false;
1467   // Use the call operator overload here.
1468   output = functional(torch::ones(5, torch::requires_grad()));
1469   ASSERT_TRUE(was_called);
1470   ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1471 }
1472 
1473 TEST_F(ModulesTest, FunctionalWithTorchFunction) {
1474   auto functional = Functional(torch::relu);
1475   ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1476   ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1477   ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
1478 }
1479 
1480 TEST_F(ModulesTest, FunctionalArgumentBinding) {
1481   auto functional =
1482       Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
1483   ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
1484 }
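// Worked through: with the extra arguments bound, functional(x) evaluates
// torch::elu(x, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1). A scale of 0
// multiplies the whole activation by zero, so the output is 0 for any input,
// which is what the assertion checks for x = 1.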
1485 
1486 TEST_F(ModulesTest, BatchNorm1dStateful) {
1487   BatchNorm1d bn(5);
1488 
1489   ASSERT_TRUE(bn->options.track_running_stats());
1490 
1491   ASSERT_TRUE(bn->running_mean.defined());
1492   ASSERT_EQ(bn->running_mean.dim(), 1);
1493   ASSERT_EQ(bn->running_mean.size(0), 5);
1494 
1495   ASSERT_TRUE(bn->running_var.defined());
1496   ASSERT_EQ(bn->running_var.dim(), 1);
1497   ASSERT_EQ(bn->running_var.size(0), 5);
1498 
1499   ASSERT_TRUE(bn->num_batches_tracked.defined());
1500   ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1501 
1502   ASSERT_TRUE(bn->options.affine());
1503 
1504   ASSERT_TRUE(bn->weight.defined());
1505   ASSERT_EQ(bn->weight.dim(), 1);
1506   ASSERT_EQ(bn->weight.size(0), 5);
1507 
1508   ASSERT_TRUE(bn->bias.defined());
1509   ASSERT_EQ(bn->bias.dim(), 1);
1510   ASSERT_EQ(bn->bias.size(0), 5);
1511 }
1512 
1513 TEST_F(ModulesTest, BatchNorm1dStateless) {
1514   BatchNorm1d bn(
1515       BatchNorm1dOptions(5).track_running_stats(false).affine(false));
1516 
1517   ASSERT_FALSE(bn->running_mean.defined());
1518   ASSERT_FALSE(bn->running_var.defined());
1519   ASSERT_FALSE(bn->num_batches_tracked.defined());
1520   ASSERT_FALSE(bn->weight.defined());
1521   ASSERT_FALSE(bn->bias.defined());
1522 }
1523 
1524 TEST_F(ModulesTest, BatchNorm1d) {
1525   BatchNorm1d bn(5);
1526   bn->eval();
1527 
1528   auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1529   auto output = bn->forward(input);
1530   auto expected = torch::tensor(
1531       {{{0.0000, 1.0000},
1532         {2.0000, 3.0000},
1533         {4.0000, 5.0000},
1534         {6.0000, 7.0000},
1535         {8.0000, 9.0000}},
1536        {{10.0000, 10.9999},
1537         {11.9999, 12.9999},
1538         {13.9999, 14.9999},
1539         {15.9999, 16.9999},
1540         {17.9999, 18.9999}}});
1541   ASSERT_TRUE(output.allclose(expected));
1542   auto s = output.sum();
1543   s.backward();
1544 
1545   ASSERT_EQ(input.sizes(), input.grad().sizes());
1546 }
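// The near-identity output follows from eval() with freshly initialized
// running stats (mean = 0, var = 1) and affine defaults (weight = 1, bias = 0):
//   y = (x - 0) / sqrt(1 + eps) = x / sqrt(1 + 1e-5),
// which shrinks larger inputs just enough to turn 11 into 10.9999 above.
// The same arithmetic explains the BatchNorm2d/BatchNorm3d tests below.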
1547 
1548 TEST_F(ModulesTest, BatchNorm2dStateful) {
1549   BatchNorm2d bn(5);
1550 
1551   ASSERT_TRUE(bn->options.track_running_stats());
1552 
1553   ASSERT_TRUE(bn->running_mean.defined());
1554   ASSERT_EQ(bn->running_mean.dim(), 1);
1555   ASSERT_EQ(bn->running_mean.size(0), 5);
1556 
1557   ASSERT_TRUE(bn->running_var.defined());
1558   ASSERT_EQ(bn->running_var.dim(), 1);
1559   ASSERT_EQ(bn->running_var.size(0), 5);
1560 
1561   ASSERT_TRUE(bn->num_batches_tracked.defined());
1562   ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1563 
1564   ASSERT_TRUE(bn->options.affine());
1565 
1566   ASSERT_TRUE(bn->weight.defined());
1567   ASSERT_EQ(bn->weight.dim(), 1);
1568   ASSERT_EQ(bn->weight.size(0), 5);
1569 
1570   ASSERT_TRUE(bn->bias.defined());
1571   ASSERT_EQ(bn->bias.dim(), 1);
1572   ASSERT_EQ(bn->bias.size(0), 5);
1573 }
1574 
1575 TEST_F(ModulesTest, BatchNorm2dStateless) {
1576   BatchNorm2d bn(
1577       BatchNorm2dOptions(5).track_running_stats(false).affine(false));
1578 
1579   ASSERT_FALSE(bn->running_mean.defined());
1580   ASSERT_FALSE(bn->running_var.defined());
1581   ASSERT_FALSE(bn->num_batches_tracked.defined());
1582   ASSERT_FALSE(bn->weight.defined());
1583   ASSERT_FALSE(bn->bias.defined());
1584 }
1585 
1586 TEST_F(ModulesTest, BatchNorm2d) {
1587   BatchNorm2d bn(5);
1588   bn->eval();
1589 
1590   auto input =
1591       torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1592   auto output = bn->forward(input);
1593   auto expected = torch::tensor(
1594       {{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1595         {{4.0000, 5.0000}, {6.0000, 7.0000}},
1596         {{8.0000, 9.0000}, {10.0000, 10.9999}},
1597         {{11.9999, 12.9999}, {13.9999, 14.9999}},
1598         {{15.9999, 16.9999}, {17.9999, 18.9999}}},
1599        {{{19.9999, 20.9999}, {21.9999, 22.9999}},
1600         {{23.9999, 24.9999}, {25.9999, 26.9999}},
1601         {{27.9999, 28.9999}, {29.9998, 30.9998}},
1602         {{31.9998, 32.9998}, {33.9998, 34.9998}},
1603         {{35.9998, 36.9998}, {37.9998, 38.9998}}}});
1604   ASSERT_TRUE(output.allclose(expected));
1605   auto s = output.sum();
1606   s.backward();
1607 
1608   ASSERT_EQ(input.sizes(), input.grad().sizes());
1609 }
1610 
1611 TEST_F(ModulesTest, BatchNorm3dStateful) {
1612   BatchNorm3d bn(5);
1613 
1614   ASSERT_TRUE(bn->options.track_running_stats());
1615 
1616   ASSERT_TRUE(bn->running_mean.defined());
1617   ASSERT_EQ(bn->running_mean.dim(), 1);
1618   ASSERT_EQ(bn->running_mean.size(0), 5);
1619 
1620   ASSERT_TRUE(bn->running_var.defined());
1621   ASSERT_EQ(bn->running_var.dim(), 1);
1622   ASSERT_EQ(bn->running_var.size(0), 5);
1623 
1624   ASSERT_TRUE(bn->num_batches_tracked.defined());
1625   ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1626 
1627   ASSERT_TRUE(bn->options.affine());
1628 
1629   ASSERT_TRUE(bn->weight.defined());
1630   ASSERT_EQ(bn->weight.dim(), 1);
1631   ASSERT_EQ(bn->weight.size(0), 5);
1632 
1633   ASSERT_TRUE(bn->bias.defined());
1634   ASSERT_EQ(bn->bias.dim(), 1);
1635   ASSERT_EQ(bn->bias.size(0), 5);
1636 }
1637 
1638 TEST_F(ModulesTest, BatchNorm3dStateless) {
1639   BatchNorm3d bn(
1640       BatchNorm3dOptions(5).track_running_stats(false).affine(false));
1641 
1642   ASSERT_FALSE(bn->running_mean.defined());
1643   ASSERT_FALSE(bn->running_var.defined());
1644   ASSERT_FALSE(bn->num_batches_tracked.defined());
1645   ASSERT_FALSE(bn->weight.defined());
1646   ASSERT_FALSE(bn->bias.defined());
1647 }
1648 
1649 TEST_F(ModulesTest, BatchNorm3d) {
1650   BatchNorm3d bn(5);
1651   bn->eval();
1652 
1653   auto input =
1654       torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1655   auto output = bn->forward(input);
1656   auto expected = torch::tensor(
1657       {{{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1658          {{4.0000, 5.0000}, {6.0000, 7.0000}}},
1659         {{{8.0000, 9.0000}, {10.0000, 10.9999}},
1660          {{11.9999, 12.9999}, {13.9999, 14.9999}}},
1661         {{{15.9999, 16.9999}, {17.9999, 18.9999}},
1662          {{19.9999, 20.9999}, {21.9999, 22.9999}}},
1663         {{{23.9999, 24.9999}, {25.9999, 26.9999}},
1664          {{27.9999, 28.9999}, {29.9998, 30.9998}}},
1665         {{{31.9998, 32.9998}, {33.9998, 34.9998}},
1666          {{35.9998, 36.9998}, {37.9998, 38.9998}}}},
1667        {{{{39.9998, 40.9998}, {41.9998, 42.9998}},
1668          {{43.9998, 44.9998}, {45.9998, 46.9998}}},
1669         {{{47.9998, 48.9998}, {49.9997, 50.9997}},
1670          {{51.9997, 52.9997}, {53.9997, 54.9997}}},
1671         {{{55.9997, 56.9997}, {57.9997, 58.9997}},
1672          {{59.9997, 60.9997}, {61.9997, 62.9997}}},
1673         {{{63.9997, 64.9997}, {65.9997, 66.9997}},
1674          {{67.9997, 68.9997}, {69.9996, 70.9996}}},
1675         {{{71.9996, 72.9996}, {73.9996, 74.9996}},
1676          {{75.9996, 76.9996}, {77.9996, 78.9996}}}}});
1677   ASSERT_TRUE(output.allclose(expected));
1678   auto s = output.sum();
1679   s.backward();
1680 
1681   ASSERT_EQ(input.sizes(), input.grad().sizes());
1682 }
1683 
1684 TEST_F(ModulesTest, InstanceNorm1dStateful) {
1685   InstanceNorm1d instance_norm(
1686       InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
1687 
1688   ASSERT_TRUE(instance_norm->options.track_running_stats());
1689 
1690   ASSERT_TRUE(instance_norm->running_mean.defined());
1691   ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1692   ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1693 
1694   ASSERT_TRUE(instance_norm->running_var.defined());
1695   ASSERT_EQ(instance_norm->running_var.dim(), 1);
1696   ASSERT_EQ(instance_norm->running_var.size(0), 5);
1697 
1698   ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1699   ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1700 
1701   ASSERT_TRUE(instance_norm->options.affine());
1702 
1703   ASSERT_TRUE(instance_norm->weight.defined());
1704   ASSERT_EQ(instance_norm->weight.dim(), 1);
1705   ASSERT_EQ(instance_norm->weight.size(0), 5);
1706 
1707   ASSERT_TRUE(instance_norm->bias.defined());
1708   ASSERT_EQ(instance_norm->bias.dim(), 1);
1709   ASSERT_EQ(instance_norm->bias.size(0), 5);
1710 }
1711 
1712 TEST_F(ModulesTest, InstanceNorm1dStateless) {
1713   InstanceNorm1d instance_norm(
1714       InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
1715 
1716   ASSERT_FALSE(instance_norm->running_mean.defined());
1717   ASSERT_FALSE(instance_norm->running_var.defined());
1718   ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1719   ASSERT_FALSE(instance_norm->weight.defined());
1720   ASSERT_FALSE(instance_norm->bias.defined());
1721 }
1722 
1723 TEST_F(ModulesTest, InstanceNorm1d) {
1724   InstanceNorm1d instance_norm(5);
1725   instance_norm->eval();
1726 
1727   auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1728   auto output = instance_norm->forward(input);
1729   auto expected = torch::tensor(
1730       {{{-1.0000, 1.0000},
1731         {-1.0000, 1.0000},
1732         {-1.0000, 1.0000},
1733         {-1.0000, 1.0000},
1734         {-1.0000, 1.0000}},
1735        {{-1.0000, 1.0000},
1736         {-1.0000, 1.0000},
1737         {-1.0000, 1.0000},
1738         {-1.0000, 1.0000},
1739         {-1.0000, 1.0000}}});
1740   ASSERT_TRUE(output.allclose(expected, 1e-3));
1741   auto s = output.sum();
1742   s.backward();
1743 
1744   ASSERT_EQ(input.sizes(), input.grad().sizes());
1745 }
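// Why every row is {-1, 1}: each (batch, channel) slice holds two consecutive
// values {a, a + 1}, and normalizing a length-2 vector by its own mean
// (a + 0.5) and biased std (0.5) always yields {-1, +1} up to eps.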
1746 
1747 TEST_F(ModulesTest, InstanceNorm2dStateful) {
1748   InstanceNorm2d instance_norm(
1749       InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
1750 
1751   ASSERT_TRUE(instance_norm->options.track_running_stats());
1752 
1753   ASSERT_TRUE(instance_norm->running_mean.defined());
1754   ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1755   ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1756 
1757   ASSERT_TRUE(instance_norm->running_var.defined());
1758   ASSERT_EQ(instance_norm->running_var.dim(), 1);
1759   ASSERT_EQ(instance_norm->running_var.size(0), 5);
1760 
1761   ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1762   ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1763 
1764   ASSERT_TRUE(instance_norm->options.affine());
1765 
1766   ASSERT_TRUE(instance_norm->weight.defined());
1767   ASSERT_EQ(instance_norm->weight.dim(), 1);
1768   ASSERT_EQ(instance_norm->weight.size(0), 5);
1769 
1770   ASSERT_TRUE(instance_norm->bias.defined());
1771   ASSERT_EQ(instance_norm->bias.dim(), 1);
1772   ASSERT_EQ(instance_norm->bias.size(0), 5);
1773 }
1774 
1775 TEST_F(ModulesTest, InstanceNorm2dStateless) {
1776   InstanceNorm2d instance_norm(
1777       InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
1778 
1779   ASSERT_FALSE(instance_norm->running_mean.defined());
1780   ASSERT_FALSE(instance_norm->running_var.defined());
1781   ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1782   ASSERT_FALSE(instance_norm->weight.defined());
1783   ASSERT_FALSE(instance_norm->bias.defined());
1784 }
1785 
1786 TEST_F(ModulesTest, InstanceNorm2d) {
1787   InstanceNorm2d instance_norm(5);
1788   instance_norm->eval();
1789 
1790   auto input =
1791       torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1792   auto output = instance_norm->forward(input);
1793   auto expected = torch::tensor(
1794       {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1795         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1796         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1797         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1798         {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
1799        {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1800         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1801         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1802         {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1803         {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});
1804   ASSERT_TRUE(output.allclose(expected, 1e-3));
1805   auto s = output.sum();
1806   s.backward();
1807 
1808   ASSERT_EQ(input.sizes(), input.grad().sizes());
1809 }
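// Worked through for one 2x2 slice: {0, 1, 2, 3} has mean 1.5 and biased std
// sqrt(1.25) ~= 1.1180, so (x - 1.5) / 1.1180 gives
// {-1.3416, -0.4472, 0.4472, 1.3416}. Every slice of consecutive values
// normalizes to this same pattern; the 3d test below follows the same recipe
// with 8 consecutive values per slice.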
1810 
1811 TEST_F(ModulesTest, InstanceNorm3dStateful) {
1812   InstanceNorm3d instance_norm(
1813       InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
1814 
1815   ASSERT_TRUE(instance_norm->options.track_running_stats());
1816 
1817   ASSERT_TRUE(instance_norm->running_mean.defined());
1818   ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1819   ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1820 
1821   ASSERT_TRUE(instance_norm->running_var.defined());
1822   ASSERT_EQ(instance_norm->running_var.dim(), 1);
1823   ASSERT_EQ(instance_norm->running_var.size(0), 5);
1824 
1825   ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1826   ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1827 
1828   ASSERT_TRUE(instance_norm->options.affine());
1829 
1830   ASSERT_TRUE(instance_norm->weight.defined());
1831   ASSERT_EQ(instance_norm->weight.dim(), 1);
1832   ASSERT_EQ(instance_norm->weight.size(0), 5);
1833 
1834   ASSERT_TRUE(instance_norm->bias.defined());
1835   ASSERT_EQ(instance_norm->bias.dim(), 1);
1836   ASSERT_EQ(instance_norm->bias.size(0), 5);
1837 }
1838 
1839 TEST_F(ModulesTest, InstanceNorm3dStateless) {
1840   InstanceNorm3d instance_norm(
1841       InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
1842 
1843   ASSERT_FALSE(instance_norm->running_mean.defined());
1844   ASSERT_FALSE(instance_norm->running_var.defined());
1845   ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1846   ASSERT_FALSE(instance_norm->weight.defined());
1847   ASSERT_FALSE(instance_norm->bias.defined());
1848 }
1849 
1850 TEST_F(ModulesTest, InstanceNorm3d) {
1851   InstanceNorm3d instance_norm(5);
1852   instance_norm->eval();
1853 
1854   auto input =
1855       torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1856   auto output = instance_norm->forward(input);
1857   auto expected = torch::tensor(
1858       {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1859          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1860         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1861          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1862         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1863          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1864         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1865          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1866         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1867          {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
1868        {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1869          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1870         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1871          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1872         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1873          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1874         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1875          {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1876         {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1877          {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
1878   ASSERT_TRUE(output.allclose(expected, 1e-3));
1879   auto s = output.sum();
1880   s.backward();
1881 
1882   ASSERT_EQ(input.sizes(), input.grad().sizes());
1883 }
1884 
1885 TEST_F(ModulesTest, Linear_CUDA) {
1886   Linear model(5, 2);
1887   model->to(torch::kCUDA);
1888   auto x =
1889       torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
1890   auto y = model(x);
1891   torch::Tensor s = y.sum();
1892 
1893   s.backward();
1894   ASSERT_EQ(y.ndimension(), 2);
1895   ASSERT_EQ(s.ndimension(), 0);
1896   ASSERT_EQ(y.size(0), 10);
1897   ASSERT_EQ(y.size(1), 2);
1898 
1899   ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1900 }
1901 
1902 TEST_F(ModulesTest, Linear2_CUDA) {
1903   Linear model(5, 2);
1904   model->to(torch::kCUDA);
1905   model->to(torch::kCPU);
1906   auto x = torch::randn({10, 5}, torch::requires_grad());
1907   auto y = model(x);
1908   torch::Tensor s = y.sum();
1909 
1910   s.backward();
1911   ASSERT_EQ(y.ndimension(), 2);
1912   ASSERT_EQ(s.ndimension(), 0);
1913   ASSERT_EQ(y.size(0), 10);
1914   ASSERT_EQ(y.size(1), 2);
1915 
1916   ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1917 }
1918 
1919 TEST_F(ModulesTest, L1Loss) {
1920   L1Loss loss;
1921   auto input = torch::randn({5, 6}, torch::requires_grad());
1922   auto target = torch::empty({5, 6}).random_(2);
1923   auto output = loss->forward(torch::sigmoid(input), target);
1924   auto s = output.sum();
1925   s.backward();
1926 
1927   ASSERT_EQ(output.sizes(), std::vector<int64_t>());
1928   ASSERT_EQ(input.sizes(), input.grad().sizes());
1929 }
1930 
1931 TEST_F(ModulesTest, MSELoss) {
1932   MSELoss loss;
1933   auto input = torch::randn({5, 6}, torch::requires_grad());
1934   auto target = torch::empty({5, 6}).random_(2);
1935   auto output = loss->forward(torch::sigmoid(input), target);
1936   auto s = output.sum();
1937   s.backward();
1938 
1939   ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1940   ASSERT_EQ(input.sizes(), input.grad().sizes());
1941 }
1942 
1943 TEST_F(ModulesTest, BCELoss) {
1944   BCELoss loss;
1945   auto input = torch::randn({5, 6}, torch::requires_grad());
1946   auto target = torch::empty({5, 6}).random_(2);
1947   auto output = loss->forward(torch::sigmoid(input), target);
1948   auto s = output.sum();
1949   s.backward();
1950 
1951   ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1952   ASSERT_EQ(input.sizes(), input.grad().sizes());
1953 }
1954 
1955 TEST_F(ModulesTest, KLDivLoss) {
1956   KLDivLoss loss;
1957   auto input = torch::randn({5, 6}, torch::requires_grad());
1958   auto target = torch::empty({5, 6}).random_(2);
1959   auto output = loss->forward(torch::sigmoid(input), target);
1960   auto s = output.sum();
1961   s.backward();
1962 
1963   ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1964   ASSERT_EQ(input.sizes(), input.grad().sizes());
1965 }
1966 
1967 TEST_F(ModulesTest, HingeEmbeddingLoss) {
1968   HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2));
1969   auto input = torch::tensor(
1970       {{2, 22, 4}, {20, 10, 0}},
1971       torch::dtype(torch::kFloat).requires_grad(true));
1972   auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
1973   auto output = loss->forward(input, target);
1974   auto expected = torch::tensor({10}, torch::kFloat);
1975   auto s = output.sum();
1976   s.backward();
1977 
1978   ASSERT_TRUE(output.allclose(expected));
1979   ASSERT_EQ(input.sizes(), input.grad().sizes());
1980 }
1981 
1982 TEST_F(ModulesTest, MultiMarginLoss) {
1983   auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
1984   MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight));
1985   auto input = torch::tensor(
1986       {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
1987       torch::dtype(torch::kFloat).requires_grad(true));
1988   auto target = torch::tensor({2, 1, 0}, torch::kLong);
1989   auto output = loss->forward(input, target);
1990   auto expected = torch::tensor({0.305556}, torch::kFloat);
1991   auto s = output.sum();
1992   s.backward();
1993 
1994   ASSERT_TRUE(output.allclose(expected, 1e-04));
1995   ASSERT_EQ(input.sizes(), input.grad().sizes());
1996 }
1997 
1998 TEST_F(ModulesTest, CosineEmbeddingLoss) {
1999   CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5));
2000   auto input1 = torch::tensor(
2001       {{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
2002   auto input2 = torch::tensor(
2003       {{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true));
2004   auto target = torch::tensor({1, -1});
2005   auto output = cos(input1, input2, target);
2006   auto expected = torch::tensor({0.1004}, torch::kFloat);
2007   auto s = output.sum();
2008   s.backward();
2009 
2010   ASSERT_TRUE(output.allclose(expected, 1e-4));
2011   ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2012   ASSERT_EQ(input2.sizes(), input2.grad().sizes());
2013 }
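// Worked through: for target 1 the per-pair loss is 1 - cos(x1, x2) ~=
// 1 - 0.9941 = 0.0059; for target -1 it is max(0, cos(x1, x2) - margin) ~=
// 0.6949 - 0.5 = 0.1949. The default mean reduction gives
// (0.0059 + 0.1949) / 2 ~= 0.1004, the expected value above.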
2014 
2015 TEST_F(ModulesTest, SmoothL1LossDefaultOptions) {
2016   SmoothL1Loss loss;
2017   auto input = torch::tensor(
2018       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2019   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2020   auto output = loss(input, target);
2021   auto expected = torch::tensor(0.0233335, torch::kFloat);
2022   auto s = output.sum();
2023   s.backward();
2024 
2025   ASSERT_TRUE(output.allclose(expected));
2026   ASSERT_EQ(input.sizes(), input.grad().sizes());
2027 }
2028 
2029 TEST_F(ModulesTest, HuberLossDefaultOptions) {
2030   HuberLoss loss;
2031   auto input = torch::tensor(
2032       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2033   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2034   auto output = loss(input, target);
2035   auto expected = torch::tensor(0.0233335, torch::kFloat);
2036   auto s = output.sum();
2037   s.backward();
2038 
2039   ASSERT_TRUE(output.allclose(expected));
2040   ASSERT_EQ(input.sizes(), input.grad().sizes());
2041 }
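// Worked through: the elementwise errors are {0.1, 0.2, 0.3}, all below the
// default beta/delta of 1, so both losses reduce to mean(0.5 * e^2) =
// 0.5 * (0.01 + 0.04 + 0.09) / 3 ~= 0.0233335, which is why the SmoothL1 and
// Huber tests above share the same expected value.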
2042 
2043 TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) {
2044   MultiLabelMarginLoss loss;
2045   auto input = torch::tensor(
2046       {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2047   auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2048   auto output = loss->forward(input, target);
2049   auto expected = torch::tensor({0.8500}, torch::kFloat);
2050   auto s = output.sum();
2051   s.backward();
2052 
2053   ASSERT_TRUE(output.allclose(expected));
2054   ASSERT_EQ(input.sizes(), input.grad().sizes());
2055 }
2056 
2057 TEST_F(ModulesTest, SmoothL1LossNoReduction) {
2058   SmoothL1Loss loss(/*reduction=*/torch::kNone);
2059   auto input = torch::tensor(
2060       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2061   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2062   auto output = loss(input, target);
2063   auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2064   auto s = output.sum();
2065   s.backward();
2066 
2067   ASSERT_TRUE(output.allclose(expected));
2068   ASSERT_EQ(input.sizes(), input.grad().sizes());
2069 }
2070 
2071 TEST_F(ModulesTest, HuberLossNoReduction) {
2072   HuberLoss loss(/*reduction=*/torch::kNone);
2073   auto input = torch::tensor(
2074       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2075   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2076   auto output = loss(input, target);
2077   auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2078   auto s = output.sum();
2079   s.backward();
2080 
2081   ASSERT_TRUE(output.allclose(expected));
2082   ASSERT_EQ(input.sizes(), input.grad().sizes());
2083 }
2084 
2085 TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) {
2086   MultiLabelMarginLoss loss(torch::kNone);
2087   auto input = torch::tensor(
2088       {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2089   auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2090   auto output = loss->forward(input, target);
2091   auto expected = torch::tensor({0.8500}, torch::kFloat);
2092   auto s = output.sum();
2093   s.backward();
2094 
2095   ASSERT_TRUE(output.allclose(expected));
2096   ASSERT_EQ(input.sizes(), input.grad().sizes());
2097 }
2098 
2099 TEST_F(ModulesTest, SmoothL1LossBeta) {
2100   auto options = SmoothL1LossOptions().beta(0.2);
2101   SmoothL1Loss loss(options);
2102   auto input = torch::tensor(
2103       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2104   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2105   auto output = loss(input, target);
2106   auto expected = torch::tensor(0.108333, torch::kFloat);
2107   auto s = output.sum();
2108   s.backward();
2109 
2110   ASSERT_TRUE(output.allclose(expected));
2111   ASSERT_EQ(input.sizes(), input.grad().sizes());
2112 }
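// Worked through with beta = 0.2: e = 0.1 falls in the quadratic zone,
// 0.5 * 0.1^2 / 0.2 = 0.025; e = 0.2 and e = 0.3 fall in the linear zone,
// |e| - 0.5 * beta = 0.1 and 0.2; the mean is 0.325 / 3 ~= 0.108333.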
2113 
2114 TEST_F(ModulesTest, HuberLossDelta) {
2115   auto options = HuberLossOptions().delta(0.2);
2116   HuberLoss loss(options);
2117   auto input = torch::tensor(
2118       {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2119   auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2120   auto output = loss(input, target);
2121   auto expected = torch::tensor(0.0216666, torch::kFloat);
2122   auto s = output.sum();
2123   s.backward();
2124 
2125   ASSERT_TRUE(output.allclose(expected));
2126   ASSERT_EQ(input.sizes(), input.grad().sizes());
2127 }
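// Note the identity Huber(delta) = delta * SmoothL1(beta = delta): the two
// tests above agree via 0.2 * 0.108333 ~= 0.0216666.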
2128 
2129 TEST_F(ModulesTest, TripletMarginLoss) {
2130   TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0));
2131   auto anchor = torch::tensor(
2132       {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true));
2133   auto positive = torch::tensor(
2134       {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2135   auto negative = torch::tensor(
2136       {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true));
2137   auto output = loss->forward(anchor, positive, negative);
2138   auto expected = torch::tensor({0.}, torch::kFloat);
2139   auto s = output.sum();
2140   s.backward();
2141 
2142   ASSERT_TRUE(output.allclose(expected, 1e-04));
2143   ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2144 }
2145 
2146 TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
2147   // Check that TripletMarginWithDistanceLoss with its default distance
2148   // function (torch::pairwise_distance) produces the same outputs as
2149   // TripletMarginLoss under matching options.
2150 
2151   std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2152       torch::kSum, torch::kMean, torch::kNone};
2153   std::vector<float> margins = {0.5, 1.0, 1.5};
2154   std::vector<bool> swaps = {true, false};
2155 
2156   for (auto& reduction : reductions) {
2157     for (auto& margin : margins) {
2158       for (const auto swap : swaps) {
2159         auto anchor = torch::randn(
2160             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2161         auto positive = torch::randn(
2162             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2163         auto negative = torch::randn(
2164             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2165 
2166         auto basicOptions =
2167             TripletMarginLossOptions().reduction(reduction).margin(margin).swap(
2168                 swap);
2169         auto distanceOptions = TripletMarginWithDistanceLossOptions()
2170                                    .reduction(reduction)
2171                                    .margin(margin)
2172                                    .swap(swap);
2173         TripletMarginLoss basicLoss(basicOptions);
2174         TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
2175 
2176         auto basicOutput = basicLoss->forward(anchor, positive, negative);
2177         auto distanceOutput = distanceLoss->forward(anchor, positive, negative);
2178         auto basicOperatorOutput = basicLoss(anchor, positive, negative);
2179         auto distanceOperatorOutput = distanceLoss(anchor, positive, negative);
2180 
2181         ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
2182         ASSERT_TRUE(
2183             distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6));
2184         ASSERT_TRUE(
2185             distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6));
2186 
2187         // sum() reduces to a scalar so backward() works even under
2188         // torch::kNone reduction
2188         auto sum = distanceOutput.sum();
2189         sum.backward();
2190         ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2191         ASSERT_EQ(positive.sizes(), positive.grad().sizes());
2192         ASSERT_EQ(negative.sizes(), negative.grad().sizes());
2193       }
2194     }
2195   }
2196 }
2197 
2198 TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
2199   // Check for parity between F::triplet_margin_with_distance_loss and
2200   // TripletMarginWithDistanceLoss.
2201   auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2202     return torch::pairwise_distance(x, y);
2203   };
2204   auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2205     return 1.0 - torch::cosine_similarity(x, y);
2206   };
2207   std::vector<TripletMarginWithDistanceLossOptions::distance_function_t>
2208       distance_functions = {pairwise_distance, cosine_distance};
2209 
2210   std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2211       torch::kSum, torch::kMean, torch::kNone};
2212   std::vector<float> margins = {0.5, 1.0, 1.5};
2213   std::vector<bool> swaps = {true, false};
2214 
2215   for (auto& function : distance_functions) {
2216     for (auto& reduction : reductions) {
2217       for (auto& margin : margins) {
2218         for (const auto swap : swaps) {
2219           auto moduleOptions = TripletMarginWithDistanceLossOptions()
2220                                    .distance_function(function)
2221                                    .reduction(reduction)
2222                                    .margin(margin)
2223                                    .swap(swap);
2224           auto functionOptions =
2225               torch::nn::functional::TripletMarginWithDistanceLossFuncOptions()
2226                   .distance_function(function)
2227                   .reduction(reduction)
2228                   .margin(margin)
2229                   .swap(swap);
2230 
2231           auto anchor = torch::randn(
2232               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2233           auto positive = torch::randn(
2234               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2235           auto negative = torch::randn(
2236               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2237 
2238           TripletMarginWithDistanceLoss distanceLoss(moduleOptions);
2239 
2240           auto moduleOutput = distanceLoss->forward(anchor, positive, negative);
2241           auto moduleOperatorOutput = distanceLoss(anchor, positive, negative);
2242           auto functionOutput =
2243               torch::nn::functional::triplet_margin_with_distance_loss(
2244                   anchor, positive, negative, functionOptions);
2245 
2246           ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6));
2247           ASSERT_TRUE(
2248               moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6));
2249         }
2250       }
2251     }
2252   }
2253 }
2254 
2255 TEST_F(ModulesTest, NLLLoss) {
2256   NLLLoss loss;
2257   auto input = torch::tensor(
2258       {{-0.1315, -3.1315, -2.5315},
2259        {-3.7038, -0.1038, -2.6038},
2260        {-2.3422, -1.3422, -0.4422}},
2261       torch::dtype(torch::kFloat).requires_grad(true));
2262   auto target = torch::tensor({1, 0, 2}, torch::kLong);
2263   auto output = loss->forward(input, target);
2264   auto expected = torch::tensor(2.4258, torch::kFloat);
2265   auto s = output.sum();
2266   s.backward();
2267 
2268   ASSERT_TRUE(output.allclose(expected, 1e-04));
2269   ASSERT_TRUE(
2270       NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean))
2271           ->forward(input, target)
2272           .allclose(expected, 1e-04));
2273 }
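// Worked through: NLL with mean reduction picks -input[i, target[i]] per row,
// so the loss is (3.1315 + 3.7038 + 0.4422) / 3 ~= 2.4258. The second
// assertion just restates the defaults (ignore_index = -100, mean reduction)
// and must match.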
2274 
2275 TEST_F(ModulesTest, CrossEntropyLoss) {
2276   CrossEntropyLoss loss;
2277   auto input = torch::tensor(
2278       {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2279   auto target = torch::tensor({0, 1}, torch::kLong);
2280   auto output = loss->forward(input, target);
2281   auto expected = torch::tensor(0.6931, torch::kFloat);
2282   auto s = output.sum();
2283   s.backward();
2284 
2285   ASSERT_TRUE(output.allclose(expected, 1e-04));
2286   ASSERT_EQ(input.sizes(), input.grad().sizes());
2287   ASSERT_TRUE(
2288       CrossEntropyLoss(
2289           CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
2290           ->forward(input, target)
2291           .allclose(expected, 1e-04));
2292 
2293   // label smoothing with class indices
2294   loss = CrossEntropyLoss(
2295       CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
2296   input = torch::tensor(
2297       {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2298   target = torch::tensor({0, 1}, torch::kLong);
2299   output = loss->forward(input, target);
2300   expected = torch::tensor(0.3326, torch::kFloat);
2301   s = output.sum();
2302   s.backward();
2303 
2304   ASSERT_TRUE(output.allclose(expected, 1e-04));
2305   ASSERT_EQ(input.sizes(), input.grad().sizes());
2306 
2307   // label smoothing with target probabilities
2308   loss = CrossEntropyLoss(
2309       CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
2310   input = torch::tensor(
2311       {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2312   target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
2313   output = loss->forward(input, target);
2314   expected = torch::tensor(0.5701, torch::kFloat);
2315   s = output.sum();
2316   s.backward();
2317 
2318   ASSERT_TRUE(output.allclose(expected, 1e-04));
2319   ASSERT_EQ(input.sizes(), input.grad().sizes());
2320 }
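// Worked through for the first case: each row of {{3, 3}, {2, 2}} softmaxes
// to {0.5, 0.5}, so the per-row loss is -log(0.5) ~= 0.6931 regardless of the
// target class, and the mean is the same value.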
2321 
2322 TEST_F(ModulesTest, CosineSimilarity) {
2323   CosineSimilarity cos(CosineSimilarityOptions().dim(1));
2324   auto input1 = torch::tensor(
2325       {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2326   auto input2 = torch::tensor(
2327       {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2328   auto output = cos->forward(input1, input2);
2329   auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
2330   auto s = output.sum();
2331   s.backward();
2332 
2333   ASSERT_TRUE(output.allclose(expected, 1e-04));
2334   ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2335 }
2336 
2337 TEST_F(ModulesTest, SoftMarginLossDefaultOptions) {
2338   SoftMarginLoss loss;
2339   auto input = torch::tensor(
2340       {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2341   auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2342   auto output = loss->forward(input, target);
2343   auto expected = torch::tensor({1.3767317}, torch::kFloat);
2344   auto s = output.sum();
2345   s.backward();
2346 
2347   ASSERT_TRUE(output.allclose(expected));
2348   ASSERT_EQ(input.sizes(), input.grad().sizes());
2349 }
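// SoftMarginLoss computes mean(log(1 + exp(-target * input))); the four
// per-element terms are spelled out in SoftMarginLossNoReduction below, and
// their mean (5.5069 / 4 ~= 1.3767317) is the expected value here.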
2350 
2351 TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
2352   MultiLabelSoftMarginLoss loss;
2353   auto input = torch::tensor(
2354       {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2355       torch::dtype(torch::kFloat).requires_grad(true));
2356   auto target =
2357       torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2358   auto output = loss->forward(input, target);
2359   auto expected = torch::tensor({0.7608436}, torch::kFloat);
2360   auto s = output.sum();
2361   s.backward();
2362 
2363   ASSERT_TRUE(output.allclose(expected));
2364   ASSERT_EQ(input.sizes(), input.grad().sizes());
2365 }
2366 
2367 TEST_F(ModulesTest, SoftMarginLossNoReduction) {
2368   SoftMarginLoss loss(torch::kNone);
2369   auto input = torch::tensor(
2370       {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2371   auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2372   auto output = loss->forward(input, target);
2373   auto expected = torch::tensor(
2374       {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
2375   auto s = output.sum();
2376   s.backward();
2377 
2378   ASSERT_TRUE(output.allclose(expected));
2379   ASSERT_EQ(input.sizes(), input.grad().sizes());
2380 }
2381 
2382 TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
2383   auto input = torch::tensor(
2384       {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2385       torch::dtype(torch::kFloat).requires_grad(true));
2386   auto target =
2387       torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2388   auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
2389   auto options =
2390       MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight);
2391   MultiLabelSoftMarginLoss loss = MultiLabelSoftMarginLoss(options);
2392   auto output = loss->forward(input, target);
2393   auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
2394   auto s = output.sum();
2395   s.backward();
2396 
2397   ASSERT_TRUE(output.allclose(expected));
2398   ASSERT_EQ(input.sizes(), input.grad().sizes());
2399 }
2400 
2401 TEST_F(ModulesTest, PairwiseDistance) {
2402   PairwiseDistance dist(PairwiseDistanceOptions().p(1));
2403   auto input1 = torch::tensor(
2404       {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2405   auto input2 = torch::tensor(
2406       {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2407   auto output = dist->forward(input1, input2);
2408   auto expected = torch::tensor({6, 6}, torch::kFloat);
2409   auto s = output.sum();
2410   s.backward();
2411 
2412   ASSERT_TRUE(output.allclose(expected));
2413   ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2414 }
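// Worked through with p = 1: the rowwise L1 distances are
// |1-1| + |2-8| + |3-3| = 6 and |4-2| + |5-1| + |6-6| = 6, giving {6, 6}.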
2415 
2416 TEST_F(ModulesTest, ELU) {
2417   const auto size = 3;
2418   for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2419     for (const auto inplace : {false, true}) {
2420       ELU model{ELUOptions().alpha(alpha).inplace(inplace)};
2421       auto x = torch::linspace(-10.0, 10.0, size * size * size);
2422       x.resize_({size, size, size});
2423       if (!inplace) {
2424         x.requires_grad_(true);
2425       }
2426       auto x_orig = x.clone();
2427       auto y = model(x);
2428       torch::Tensor s = y.sum();
2429 
2430       ASSERT_EQ(s.ndimension(), 0);
2431 
2432       ASSERT_EQ(y.ndimension(), 3);
2433       ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2434       auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2435           torch::min(torch::zeros_like(x_orig),
2436                      alpha * (torch::exp(x_orig) - 1.0));
2437       ASSERT_TRUE(torch::allclose(y, y_exp));
2438       if (inplace) {
2439         ASSERT_TRUE(torch::allclose(x, y_exp));
2440       } else {
2441         s.backward();
2442       }
2443     }
2444   }
2445 }
2446 
2447 TEST_F(ModulesTest, SELU) {
2448   for (const auto inplace : {false, true}) {
2449     SELU model(inplace);
2450     auto input = torch::randn({5, 5});
2451     if (!inplace) {
2452       input.requires_grad_(true);
2453     }
2454     auto input_orig = input.clone();
2455     auto output = model->forward(input);
2456     const double scale = 1.0507009873554804934193349852946;
2457     const double alpha = 1.6732632423543772848170429916717;
2458     auto zero = torch::zeros_like(input);
2459     auto expected = scale *
2460         (torch::max(zero, input_orig) +
2461          torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
2462     auto s = output.sum();
2463 
2464     ASSERT_EQ(s.ndimension(), 0);
2465     ASSERT_TRUE(output.allclose(expected));
2466     if (inplace) {
2467       ASSERT_TRUE(input.allclose(expected));
2468     } else {
2469       s.backward();
2470     }
2471   }
2472 }
2473 
2474 TEST_F(ModulesTest, Hardshrink) {
2475   const auto size = 3;
2476   for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
2477     Hardshrink model{HardshrinkOptions().lambda(lambda)};
2478     auto x = torch::linspace(-10.0, 10.0, size * size * size);
2479     x.resize_({size, size, size}).set_requires_grad(true);
2480     auto y = model(x);
2481     torch::Tensor s = y.sum();
2482 
2483     s.backward();
2484     ASSERT_EQ(s.ndimension(), 0);
2485     ASSERT_EQ(y.ndimension(), 3);
2486     ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2487     auto y_exp = (x.abs() > lambda) * x;
2488     ASSERT_TRUE(torch::allclose(y, y_exp));
2489   }
2490 }
2491 
2492 TEST_F(ModulesTest, Hardtanh) {
2493   const auto size = 3;
2494   for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
2495     for (const auto max_val : {0.42, 1.0, 4.2}) {
2496       for (const auto inplace : {false, true}) {
2497         Hardtanh model{
2498             HardtanhOptions().min_val(min_val).max_val(max_val).inplace(
2499                 inplace)};
2500         auto x = torch::linspace(-10.0, 10.0, size * size * size);
2501         x.resize_({size, size, size});
2502         if (!inplace) {
2503           x.requires_grad_(true);
2504         }
2505         auto x_orig = x.clone();
2506         auto y = model(x);
2507         torch::Tensor s = y.sum();
2508 
2509         ASSERT_EQ(s.ndimension(), 0);
2510         ASSERT_EQ(y.ndimension(), 3);
2511         ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2512         auto y_exp = (x_orig < min_val) * min_val +
2513             ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig +
2514             (x_orig > max_val) * max_val;
2515         ASSERT_TRUE(torch::allclose(y, y_exp));
2516         if (inplace) {
2517           ASSERT_TRUE(torch::allclose(x, y_exp));
2518         } else {
2519           s.backward();
2520         }
2521       }
2522     }
2523   }
2524 }
2525 
2526 TEST_F(ModulesTest, HardtanhMinValGEMaxVal) {
2527   ASSERT_THROWS_WITH(
2528       Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)},
2529       "max_val must be greater than min_val");
2530   ASSERT_THROWS_WITH(
2531       Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)},
2532       "max_val must be greater than min_val");
2533 
2534   Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)};
2535   ht->options.min_val(0.42);
2536   ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2537   ht->options.max_val(-0.42);
2538   ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2539 }
2540 
2541 TEST_F(ModulesTest, LeakyReLU) {
2542   const auto size = 3;
2543   for (const auto inplace : {false, true}) {
2544     for (const auto negative_slope : {0.0, 0.42, 1.0}) {
2545       for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2546         LeakyReLU model{
2547             LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
2548         auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2549         x.resize_({size, size, size});
2550         if (!inplace) {
2551           x.requires_grad_(true);
2552         }
2553         auto x_orig = x.clone();
2554         auto y = model(x);
2555         torch::Tensor s = y.sum();
2556 
2557         ASSERT_EQ(s.ndimension(), 0);
2558         ASSERT_EQ(y.ndimension(), 3);
2559         ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2560         auto y_exp =
2561             (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
2562         ASSERT_TRUE(torch::allclose(y, y_exp));
2563         if (inplace) {
2564           ASSERT_TRUE(torch::allclose(x, y_exp));
2565         } else {
2566           s.backward();
2567         }
2568       }
2569     }
2570   }
2571 }
2572 
2573 TEST_F(ModulesTest, LogSigmoid) {
2574   const auto size = 3;
2575   LogSigmoid model;
2576   auto x = torch::linspace(-10.0, 10.0, size * size * size);
2577   x.resize_({size, size, size}).set_requires_grad(true);
2578   auto y = model(x);
2579   torch::Tensor s = y.sum();
2580 
2581   s.backward();
2582   ASSERT_EQ(s.ndimension(), 0);
2583 
2584   ASSERT_EQ(y.ndimension(), 3);
2585   ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2586   auto y_exp = torch::log(
2587       torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
2588   ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
2589 }
2590 
2591 TEST_F(ModulesTest, Softmax) {
2592   Softmax m(/*dim=*/1);
2593   auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2594   auto output = m(input);
2595   auto sum = torch::sum(torch::exp(input), 1);
2596 
2597   for (const auto i : c10::irange(2)) {
2598     auto expected = torch::exp(input[i]) / sum[i];
2599     ASSERT_TRUE(torch::allclose(output[i], expected));
2600   }
2601 }
2602 
2603 TEST_F(ModulesTest, Softmin) {
2604   Softmin m(/*dim=*/1);
2605   auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2606   auto output = m(input);
2607   auto sum = torch::sum(torch::exp(-input), 1);
2608 
2609   for (const auto i : c10::irange(2)) {
2610     auto expected = torch::exp(-input[i]) / sum[i];
2611     ASSERT_TRUE(torch::allclose(output[i], expected));
2612   }
2613 }
2614 
2615 TEST_F(ModulesTest, LogSoftmax) {
2616   LogSoftmax m(/*dim=*/1);
2617   auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2618   auto output = m(input);
2619   auto sum = torch::sum(torch::exp(input), 1);
2620 
2621   for (const auto i : c10::irange(2)) {
2622     auto expected = torch::log(torch::exp(input[i]) / sum[i]);
2623     ASSERT_TRUE(torch::allclose(output[i], expected));
2624   }
2625 }
2626 
2627 TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
2628   {
2629     // log_prob returns genuine log-probabilities: exp() of each row sums
2630     // to 1
2630     AdaptiveLogSoftmaxWithLoss asfm(
2631         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2632     auto x = torch::randn({4, 8});
2633     auto logprob_out = asfm->log_prob(x);
2634     ASSERT_TRUE(
2635         torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
2636   }
2637   {
2638     // test predict
2639     AdaptiveLogSoftmaxWithLoss asfm(
2640         AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
2641             .div_value(2.)
2642             .head_bias(true));
2643     auto x = torch::randn({64, 8});
2644     auto logprob_out = asfm->log_prob(x);
2645     auto predict_out = asfm->predict(x);
2646     ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
2647   }
2648   {
2649     // cluster sizes
2650     AdaptiveLogSoftmaxWithLoss asfm(
2651         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2652     auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
2653     auto y = torch::tensor({0, 17}, torch::kLong);
2654     auto asm_out = asfm(x, y);
2655     ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
2656   }
2657   {
2658     // forward returns the same values as log_prob
2659     AdaptiveLogSoftmaxWithLoss asfm(
2660         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2661     auto x = torch::randn({4, 8});
2662     auto logprob_out = asfm->log_prob(x);
2663     NLLLoss nll_loss;
2664 
2665     for (const auto v : c10::irange(4)) {
2666       auto y = torch::full({4}, v, torch::kLong);
2667       auto asm_out = asfm(x, y);
2668       auto out = asm_out.output;
2669       auto loss = torch::tensor(asm_out.loss, torch::kFloat);
2670       auto expected = nll_loss->forward(logprob_out, y);
2671 
2672       ASSERT_TRUE(torch::allclose(loss, expected));
2673       ASSERT_TRUE(torch::allclose(
2674           out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
2675     }
2676   }
2677   {
2678     // test no batch dim
2679     AdaptiveLogSoftmaxWithLoss asfm(
2680         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2681     auto x = torch::randn({1, 16});
2682     auto y = torch::tensor({17});
2683     auto x2 = x.squeeze(0);
2684     auto y2 = y.squeeze(0);
2685     ASSERT_TRUE(
2686         torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
2687   }
2688   {
2689     // test div_value
2690     auto options =
2691         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
2692     ASSERT_THROWS_WITH(
2693         AdaptiveLogSoftmaxWithLoss(options),
2694         "div_value should not be equal to 0");
2695 
2696     options =
2697         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
2698     ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
2699   }
2700 }
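// Background for the cases above (per the adaptive-softmax construction):
// the cutoffs split the label range into a frequent "head" handled directly
// plus tail clusters, each tail projected to a dimension shrunk by div_value
// per cluster, which is why div_value = 0 is rejected at construction time.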
2701 
2702 TEST_F(ModulesTest, Softmax2d) {
2703   Softmax2d m;
2704   auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
2705   auto output = m(input);
2706   auto sum = torch::sum(torch::exp(input), 1);
2707 
2708   for (const auto i : c10::irange(1)) {
2709     for (const auto j : c10::irange(2)) {
2710       for (const auto k : c10::irange(3)) {
2711         for (const auto l : c10::irange(4)) {
2712           auto expected = torch::exp(input[i][j][k][l]) / sum[i][k][l];
2713           ASSERT_TRUE(torch::allclose(output[i][j][k][l], expected));
2714         }
2715       }
2716     }
2717   }
2718 }
2719 
2720 TEST_F(ModulesTest, PReLU) {
2721   const auto num_parameters = 42;
2722   const auto init = 0.42;
2723 
2724   PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)};
2725 
2726   ASSERT_EQ(model->weight.sizes(), std::vector<int64_t>({num_parameters}));
2727   ASSERT_TRUE(
2728       torch::allclose(model->weight, torch::full(num_parameters, init)));
2729 
2730   const auto x = torch::rand({100, num_parameters}) * 200 - 100;
2731   const auto y = model(x);
2732   const auto s = y.sum();
2733 
2734   s.backward();
2735   ASSERT_EQ(s.ndimension(), 0);
2736 
2737   ASSERT_EQ(y.ndimension(), x.ndimension());
2738   ASSERT_EQ(y.sizes(), x.sizes());
2739   const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x;
2740   ASSERT_TRUE(torch::allclose(y, y_exp));
2741 }
2742 
2743 TEST_F(ModulesTest, ReLU) {
2744   for (const auto inplace : {false, true}) {
2745     const auto size = 3;
2746     ReLU model(inplace);
2747     auto x = torch::linspace(-10.0, 10.0, size * size * size);
2748     x.resize_({size, size, size});
2749     if (!inplace) {
2750       x.requires_grad_(true);
2751     }
2752     auto x_orig = x.clone();
2753     auto y = model(x);
2754     torch::Tensor s = y.sum();
2755 
2756     ASSERT_EQ(s.ndimension(), 0);
2757     ASSERT_EQ(y.ndimension(), 3);
2758     ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2759     auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig;
2760     ASSERT_TRUE(torch::allclose(y, y_exp));
2761     if (inplace) {
2762       ASSERT_TRUE(torch::allclose(x, y_exp));
2763     } else {
2764       s.backward();
2765     }
2766   }
2767 }
2768 
2769 TEST_F(ModulesTest, ReLU6) {
2770   for (const auto inplace : {false, true}) {
2771     const auto size = 3;
2772     ReLU6 model(inplace);
2773     auto x = torch::linspace(-10.0, 10.0, size * size * size);
2774     x.resize_({size, size, size});
2775     if (!inplace) {
2776       x.requires_grad_(true);
2777     }
2778     auto x_orig = x.clone();
2779     auto y = model(x);
2780     torch::Tensor s = y.sum();
2781 
2782     ASSERT_EQ(s.ndimension(), 0);
2783     ASSERT_EQ(y.ndimension(), 3);
2784     ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2785     auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig +
2786         (x_orig > 6) * 6;
2787     ASSERT_TRUE(torch::allclose(y, y_exp));
2788     if (inplace) {
2789       ASSERT_TRUE(torch::allclose(x, y_exp));
2790     } else {
2791       s.backward();
2792     }
2793   }
2794 }
2795 
2796 TEST_F(ModulesTest, RReLU) {
2797   const auto size = 3;
2798   for (const auto lower : {0.01, 0.1, 0.2}) {
2799     for (const auto upper : {0.3, 0.4, 0.5}) {
2800       for (const auto inplace : {false, true}) {
2801         for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2802           RReLU model{
2803               RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
2804           auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2805           x.resize_({size, size, size});
2806           if (!inplace) {
2807             x.requires_grad_(true);
2808           }
2809           auto x_orig = x.clone();
2810           auto y = model(x);
2811           torch::Tensor s = y.sum();
2812 
2813           ASSERT_EQ(s.ndimension(), 0);
2814           ASSERT_EQ(y.ndimension(), 3);
2815           ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2816           auto z =
2817               ((x_orig >= 0) * (x_orig == y) +
2818                (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) *
2819               1.0;
2820           ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
2821           if (inplace) {
2822             ASSERT_TRUE(torch::allclose(x, y));
2823           } else {
2824             s.backward();
2825           }
2826         }
2827       }
2828     }
2829   }
2830 }
2831 
2832 TEST_F(ModulesTest, CELU) {
2833   const auto size = 3;
2834   for (const auto inplace : {false, true}) {
2835     for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
2836       CELU model{CELUOptions().alpha(alpha).inplace(inplace)};
2837       auto x = torch::linspace(-10.0, 10.0, size * size * size);
2838       x.resize_({size, size, size});
2839       if (!inplace) {
2840         x.requires_grad_(true);
2841       }
2842       auto x_orig = x.clone();
2843       auto y = model(x);
2844       torch::Tensor s = y.sum();
2845 
2846       ASSERT_EQ(s.ndimension(), 0);
2847       ASSERT_EQ(y.ndimension(), 3);
2848       ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2849       auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2850           torch::min(torch::zeros_like(x_orig),
2851                      alpha * (torch::exp(x_orig / alpha) - 1.0));
2852       ASSERT_TRUE(torch::allclose(y, y_exp));
2853       if (inplace) {
2854         ASSERT_TRUE(torch::allclose(x, y_exp));
2855       } else {
2856         s.backward();
2857       }
2858     }
2859   }
2860 }
2861 
2862 TEST_F(ModulesTest, GLU) {
2863   int64_t dim = 1;
2864   GLU model(dim);
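  // GLU splits the input in half along `dim` and gates one half with the
  // other: out = first_half * sigmoid(second_half).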
2865   auto input = torch::randn({4, 2}, torch::requires_grad());
2866   auto output = model->forward(input);
2867   auto input_size = input.sizes()[dim] / 2;
2868   auto first_half = input.narrow(dim, 0, input_size);
2869   auto second_half = input.narrow(dim, input_size, input_size);
2870   auto expected = first_half * torch::sigmoid(second_half);
2871   auto s = output.sum();
2872   s.backward();
2873 
2874   ASSERT_EQ(s.ndimension(), 0);
2875   ASSERT_TRUE(output.allclose(expected));
2876 
2877   GLU model_default_options;
2878   ASSERT_TRUE(model_default_options->forward(input).allclose(expected));
2879 }
2880 
2881 TEST_F(ModulesTest, GELU) {
2882   GELU model(GELUOptions().approximate("none"));
2883   const auto x = torch::linspace(-3.0, 3.0, 100);
2884   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
2885   const auto y = model(x);
2886   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2887 }
2888 
2889 TEST_F(ModulesTest, TanhGELU) {
2890   GELU model(GELUOptions().approximate("tanh"));
2891   const auto x = torch::linspace(-3.0, 3.0, 100);
2892   const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
2893   const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
2894   const auto y = model(x);
2895   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2896 }
2897 
2898 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2899 TEST_F(ModulesTest, Mish) {
2900   Mish model;
2901   auto x = torch::randn(100) * 10;
2902   auto y_exp = x * x.exp().log1p().tanh();
2903   auto y = model(x);
2904 
2905   ASSERT_TRUE(torch::allclose(y, y_exp));
2906 }
2907 
2908 TEST_F(ModulesTest, Sigmoid) {
2909   Sigmoid model;
2910   auto x = torch::randn(100) * 10;
2911   auto y_exp = 1 / (1 + torch::exp(-x));
2912   auto y = model(x);
2913 
2914   ASSERT_TRUE(torch::allclose(y, y_exp));
2915 }
2916 
2917 TEST_F(ModulesTest, PixelShuffle) {
2918   PixelShuffle module(/*upscale_factor=*/2);
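  // PixelShuffle rearranges (N, C * r^2, H, W) -> (N, C, H * r, W * r);
  // with r = 2, four 2x2 channels interleave into a single 4x4 channel.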
2919   auto x = torch::tensor(
2920       {{{{-17, 19}, {-1, 2}},
2921         {{7, 14}, {-3, 1}},
2922         {{0, -2}, {-12, 14}},
2923         {{-15, 0}, {-3, 9}}}},
2924       torch::kFloat);
2925   auto y_exp = torch::tensor(
2926       {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2927       torch::kFloat);
2928   auto y = module(x);
2929 
2930   ASSERT_EQ(y.ndimension(), 4);
2931   ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
2932   ASSERT_TRUE(y.allclose(y_exp));
2933 }
2934 
2935 TEST_F(ModulesTest, PixelUnshuffle) {
2936   PixelUnshuffle module(/*downscale_factor=*/2);
2937   auto x = torch::tensor(
2938       {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2939       torch::kFloat);
2940   auto y_exp = torch::tensor(
2941       {{{{-17, 19}, {-1, 2}},
2942         {{7, 14}, {-3, 1}},
2943         {{0, -2}, {-12, 14}},
2944         {{-15, 0}, {-3, 9}}}},
2945       torch::kFloat);
2946   auto y = module(x);
2947 
2948   ASSERT_EQ(y.ndimension(), 4);
2949   ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
2950   ASSERT_TRUE(y.allclose(y_exp));
2951 }
2952 
2953 TEST_F(ModulesTest, Softplus) {
2954   const auto size = 3;
2955   for (const auto beta : {0.5, 1.0, 2.0}) {
2956     for (const auto threshold : {1.0, 3.0, 5.0}) {
2957       Softplus model{SoftplusOptions().beta(beta).threshold(threshold)};
2958       auto x = torch::linspace(-3.0, 3.0, 61);
2959       x.resize_({size, size, size});
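      // Softplus(x) = log(1 + exp(beta * x)) / beta; for numerical stability
      // the implementation reverts to the identity once x * beta exceeds
      // `threshold`.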
2960       auto y_exp =
2961           (x * beta <= threshold) * torch::log(1 + torch::exp(x * beta)) /
2962               beta + (x * beta > threshold) * x;
2963       auto y = model(x);
2964 
2965       ASSERT_EQ(y.ndimension(), 3);
2966       ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2967       ASSERT_TRUE(torch::allclose(y, y_exp));
2968     }
2969   }
2970 }
2971 
2972 TEST_F(ModulesTest, Softshrink) {
2973   const auto size = 3;
2974   for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2975     Softshrink model{/*lambda=*/lambda};
2976     auto x = torch::linspace(-10.0, 10.0, size * size * size);
2977     x.resize_({size, size, size}).set_requires_grad(true);
2978     auto y = model(x);
2979     torch::Tensor s = y.sum();
2980 
2981     s.backward();
2982     ASSERT_EQ(s.ndimension(), 0);
2983 
2984     ASSERT_EQ(y.ndimension(), 3);
2985     ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
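    // Softshrink shifts inputs toward zero by lambda and zeroes |x| <= lambda.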
2986     auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
2987     ASSERT_TRUE(torch::allclose(y, y_exp));
2988   }
2989 }
2990 
2991 TEST_F(ModulesTest, Softsign) {
2992   Softsign model;
2993   auto x = torch::randn(100) * 10;
2994   auto y_exp = x / (1 + x.abs());
2995   auto y = model(x);
2996 
2997   ASSERT_TRUE(torch::allclose(y, y_exp));
2998 }
2999 
3000 TEST_F(ModulesTest, Tanh) {
3001   Tanh model;
3002   auto x = torch::randn(100) * 10;
3003   auto y_exp = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp());
3004   auto y = model(x);
3005 
3006   ASSERT_TRUE(torch::allclose(y, y_exp));
3007 }
3008 
3009 TEST_F(ModulesTest, Tanhshrink) {
3010   Tanhshrink model;
3011   auto x = torch::randn(100) * 10;
3012   auto y_exp = x - x.tanh();
3013   auto y = model(x);
3014 
3015   ASSERT_TRUE(torch::allclose(y, y_exp));
3016 }
3017 
3018 TEST_F(ModulesTest, Threshold) {
3019   const auto size = 3;
3020   for (const auto threshold : {0.5, 1.0, 2.0}) {
3021     for (const auto value : {0.5, 1.0, 2.0}) {
3022       for (const auto inplace : {false, true}) {
3023         Threshold model{ThresholdOptions(threshold, value).inplace(inplace)};
3024         auto x = torch::linspace(-3.0, 3.0, 61);
3025         x.resize_({size, size, size});
3026         auto x_orig = x.clone();
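        // Threshold keeps x where x > threshold and substitutes `value`
        // everywhere else.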
3027         auto y_exp =
3028             (x_orig <= threshold) * value + (x_orig > threshold) * x_orig;
3029         auto y = model(x);
3030 
3031         ASSERT_EQ(y.ndimension(), 3);
3032         ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
3033         ASSERT_TRUE(torch::allclose(y, y_exp));
3034         if (inplace) {
3035           ASSERT_TRUE(torch::allclose(x, y_exp));
3036         }
3037       }
3038     }
3039   }
3040 }
3041 
3042 TEST_F(ModulesTest, Upsampling1D) {
3043   {
3044     Upsample model(UpsampleOptions()
3045                        .size(std::vector<int64_t>({4}))
3046                        .mode(torch::kNearest));
3047     auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3048     auto output = model->forward(input);
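    // Interpolating a constant input is constant for every mode, so the
    // expected output is simply all ones at the target size.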
3049     auto expected = torch::ones({1, 1, 4});
3050     auto s = output.sum();
3051     s.backward();
3052 
3053     ASSERT_EQ(s.ndimension(), 0);
3054     ASSERT_TRUE(output.allclose(expected));
3055   }
3056   {
3057     for (const auto align_corners : {true, false}) {
3058       // test float scale factor up & down sampling
3059       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3060         Upsample model(UpsampleOptions()
3061                            .scale_factor(std::vector<double>({scale_factor}))
3062                            .mode(torch::kLinear)
3063                            .align_corners(align_corners));
3064         auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3065         auto output = model->forward(input);
3066         auto expected_size =
3067             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3068         auto expected = torch::ones({1, 1, expected_size});
3069         auto s = output.sum();
3070         s.backward();
3071 
3072         ASSERT_EQ(s.ndimension(), 0);
3073         ASSERT_TRUE(output.allclose(expected));
3074       }
3075     }
3076   }
3077   {
3078     // linear (1D) upsampling spatial invariance
3079     Upsample model(UpsampleOptions()
3080                        .scale_factor(std::vector<double>({3}))
3081                        .mode(torch::kLinear)
3082                        .align_corners(false));
3083     auto input = torch::zeros({1, 1, 9});
3084     input.narrow(2, 0, 4).normal_();
3085     auto output = model->forward(input);
3086     auto expected = model->forward(input.narrow(2, 0, 5));
3087 
3088     ASSERT_TRUE(torch::allclose(output.narrow(2, 0, 15), expected));
3089   }
3090 }
3091 
3092 TEST_F(ModulesTest, Upsampling2D) {
3093   {
3094     Upsample model(UpsampleOptions()
3095                        .size(std::vector<int64_t>({4, 4}))
3096                        .mode(torch::kNearest));
3097     auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3098     auto output = model->forward(input);
3099     auto expected = torch::ones({1, 1, 4, 4});
3100     auto s = output.sum();
3101     s.backward();
3102 
3103     ASSERT_EQ(s.ndimension(), 0);
3104     ASSERT_TRUE(output.allclose(expected));
3105   }
3106   {
3107     for (const auto align_corners : {true, false}) {
3108       // test float scale factor up & down sampling
3109       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3110         Upsample model(
3111             UpsampleOptions()
3112                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3113                 .mode(torch::kBilinear)
3114                 .align_corners(align_corners));
3115         auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3116         auto output = model->forward(input);
3117         auto expected_size =
3118             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3119         auto expected = torch::ones({1, 1, expected_size, expected_size});
3120         auto s = output.sum();
3121         s.backward();
3122 
3123         ASSERT_EQ(s.ndimension(), 0);
3124         ASSERT_TRUE(output.allclose(expected));
3125       }
3126     }
3127   }
3128   {
3129     for (const auto align_corners : {true, false}) {
3130       // test float scale factor up & down sampling
3131       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3132         Upsample model(
3133             UpsampleOptions()
3134                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3135                 .mode(torch::kBicubic)
3136                 .align_corners(align_corners));
3137         auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3138         auto output = model->forward(input);
3139         auto expected_size =
3140             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3141         auto expected = torch::ones({1, 1, expected_size, expected_size});
3142         auto s = output.sum();
3143         s.backward();
3144 
3145         ASSERT_EQ(s.ndimension(), 0);
3146         ASSERT_TRUE(output.allclose(expected));
3147       }
3148     }
3149   }
3150 }
3151 
3152 TEST_F(ModulesTest, Upsampling3D) {
3153   {
3154     Upsample model(UpsampleOptions()
3155                        .size(std::vector<int64_t>({4, 4, 4}))
3156                        .mode(torch::kNearest));
3157     auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3158     auto output = model->forward(input);
3159     auto expected = torch::ones({1, 1, 4, 4, 4});
3160     auto s = output.sum();
3161     s.backward();
3162 
3163     ASSERT_EQ(s.ndimension(), 0);
3164     ASSERT_TRUE(output.allclose(expected));
3165   }
3166   {
3167     for (const auto align_corners : {true, false}) {
3168       // test float scale factor up & down sampling
3169       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3170         Upsample model(UpsampleOptions()
3171                            .scale_factor(std::vector<double>(
3172                                {scale_factor, scale_factor, scale_factor}))
3173                            .mode(torch::kTrilinear)
3174                            .align_corners(align_corners));
3175         auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3176         auto output = model->forward(input);
3177         auto expected_size =
3178             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3179         auto expected =
3180             torch::ones({1, 1, expected_size, expected_size, expected_size});
3181         auto s = output.sum();
3182         s.backward();
3183 
3184         ASSERT_EQ(s.ndimension(), 0);
3185         ASSERT_TRUE(output.allclose(expected));
3186       }
3187     }
3188   }
3189 }
3190 
3191 TEST_F(ModulesTest, CTCLoss) {
3192   CTCLoss loss{CTCLossOptions().reduction(torch::kNone)};
3193   const auto target_lengths = torch::tensor({0, 0, 0});
3194   const auto input_lengths = torch::tensor({50, 50, 50});
3195   const auto targets =
3196       torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
3197   const auto log_probs =
3198       torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
3199   const auto output =
3200       loss->forward(log_probs, targets, input_lengths, target_lengths);
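  // With empty targets the only valid alignment is all blanks (index 0), so
  // each per-sample loss is the negative sum over time of blank log-probs.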
3201   ASSERT_TRUE(output.ge(0).all().item<bool>());
3202   ASSERT_TRUE(torch::allclose(
3203       -log_probs.sum(0).slice(1, 0, 1).view_as(output), output));
3204 }
3205 
3206 TEST_F(ModulesTest, PoissonNLLLoss) {
3207   const auto input = torch::tensor({0.5, 1.5, 2.5});
3208   const auto target = torch::tensor({1., 2., 3.});
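  // With log_input=true (the default), the element-wise Poisson NLL is
  // exp(input) - target * input; the optional Stirling term is off by default.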
3209   const auto component_wise_loss = torch::exp(input) - target * input;
3210   {
3211     PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kNone)};
3212     ASSERT_TRUE(
3213         torch::allclose(component_wise_loss, loss->forward(input, target)));
3214   }
3215   {
3216     PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kSum)};
3217     ASSERT_TRUE(torch::allclose(
3218         torch::sum(component_wise_loss), loss->forward(input, target)));
3219   }
3220   {
3221     PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kMean)};
3222     ASSERT_TRUE(torch::allclose(
3223         torch::mean(component_wise_loss), loss->forward(input, target)));
3224   }
3225 }
3226 
3227 TEST_F(ModulesTest, MarginRankingLoss) {
3228   {
3229     MarginRankingLoss loss;
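    // loss(x1, x2, y) = max(0, -y * (x1 - x2) + margin); margin defaults to 0
    // and the default reduction is mean.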
3230     const auto input1 = torch::randn(15) * 10;
3231     const auto input2 = torch::randn(15) * 10;
3232     const auto target = torch::randn(15).sign();
3233     ASSERT_TRUE(torch::allclose(
3234         loss->forward(input1, input2, target),
3235         (-target * (input1 - input2)).clamp(0).mean()));
3236   }
3237   {
3238     MarginRankingLoss loss{
3239         MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)};
3240     const auto input1 = torch::randn(15) * 10;
3241     const auto input2 = torch::randn(15) * 10;
3242     const auto target = torch::randn(15).sign();
3243     const auto margin = 0.5;
3244     ASSERT_TRUE(torch::allclose(
3245         loss->forward(input1, input2, target),
3246         (-target * (input1 - input2) + margin).clamp(0).sum()));
3247   }
3248   {
3249     MarginRankingLoss loss{
3250         MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)};
3251     const auto input1 = torch::randn(15) * 10;
3252     const auto input2 = torch::randn(15) * 10;
3253     const auto target = torch::randn(15).sign();
3254     const auto margin = 0.5;
3255     ASSERT_TRUE(torch::allclose(
3256         loss->forward(input1, input2, target),
3257         (-target * (input1 - input2) + margin).clamp(0).mean()));
3258   }
3259 }
3260 
3261 TEST_F(ModulesTest, BCEWithLogitsLoss) {
3262   { // test BCE with logits raises if target and input are different size
3263     {
3264       const auto target = torch::rand(5);
3265       const auto input = torch::rand({5, 1});
3266       ASSERT_THROWS_WITH(
3267           BCEWithLogitsLoss()(input, target), "must be the same as input size");
3268     }
3269     {
3270       const auto target = torch::rand({5, 1});
3271       const auto input = torch::rand(5);
3272       ASSERT_THROWS_WITH(
3273           BCEWithLogitsLoss()(input, target), "must be the same as input size");
3274     }
3275   }
3276   { // test BCE with logits gives same result as sigmoid and bce loss
3277     auto sigmoid = Sigmoid();
3278 
3279     auto target = torch::rand({64, 4});
3280     auto output = torch::rand({64, 4}) - 0.5;
3281 
3282     ASSERT_TRUE(torch::allclose(
3283         BCEWithLogitsLoss()(output, target),
3284         BCELoss()(sigmoid(output), target)));
3285 
3286     auto weight = torch::rand(4);
3287     ASSERT_TRUE(torch::allclose(
3288         BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3289             output, target),
3290         BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3291 
3292     target = torch::zeros({4, 1}, torch::kFloat);
3293     output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3294 
3295     ASSERT_TRUE(torch::allclose(
3296         BCEWithLogitsLoss()(output, target),
3297         BCELoss()(sigmoid(output), target)));
3298 
3299     ASSERT_TRUE(torch::allclose(
3300         BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))(
3301             output, target),
3302         BCELoss(BCELossOptions().reduction(torch::kNone))(
3303             sigmoid(output), target)));
3304 
3305     weight = torch::rand({1}, torch::kFloat);
3306     ASSERT_TRUE(torch::allclose(
3307         BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3308             output, target),
3309         BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3310   }
3311   { // test BCE with logits has correct grad at zero
3312     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3313     const auto target = torch::zeros({3, 1});
3314     BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))(
3315         output, target)
3316         .backward();
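    // d/dx BCEWithLogits(x, t) = sigmoid(x) - t, so at x = 0 with t = 0 each
    // element's gradient is 0.5.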
3317     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3318     ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3319   }
3320   { // test BCE with logits broadcasts weights
3321     const auto target = torch::rand({16, 4});
3322     const auto output = torch::rand({16, 4}) - 0.5;
3323 
3324     auto weight = torch::rand(4);
3325     auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3326         output, target);
3327 
3328     weight = weight.expand({16, 4}).contiguous();
3329     auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3330         output, target);
3331 
3332     ASSERT_TRUE(torch::allclose(out1, out2));
3333 
3334     weight = torch::rand({16, 1});
3335     out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3336         output, target);
3337 
3338     weight = weight.expand({16, 4}).contiguous();
3339     out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3340         output, target);
3341 
3342     ASSERT_TRUE(torch::allclose(out1, out2));
3343   }
3344   { // test BCE with logits ones in pos weights are the same as none
3345     const auto target = torch::rand({64, 4});
3346     const auto output = torch::rand({64, 4}) - 0.5;
3347     const auto pos_weight = torch::ones({64, 4});
3348 
3349     ASSERT_TRUE(torch::allclose(
3350         BCEWithLogitsLoss()(output, target),
3351         BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))(
3352             output, target)));
3353   }
3354   { // test BCE with logits broadcasts pos weights
3355     const auto target = torch::rand({64, 4});
3356     const auto output = torch::rand({64, 4}) - 0.5;
3357     const auto pos_weight = torch::rand(4);
3358     const auto out1 = BCEWithLogitsLoss(
3359         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3360 
3361     const auto pos_weight1 = pos_weight.expand({1, 4});
3362     const auto out2 = BCEWithLogitsLoss(
3363         BCEWithLogitsLossOptions().pos_weight(pos_weight1))(output, target);
3364 
3365     const auto pos_weight2 = pos_weight.expand({64, 4});
3366     const auto out3 = BCEWithLogitsLoss(
3367         BCEWithLogitsLossOptions().pos_weight(pos_weight2))(output, target);
3368 
3369     ASSERT_TRUE(torch::allclose(out1, out2));
3370     ASSERT_TRUE(torch::allclose(out1, out3));
3371   }
3372   { // test BCE with logits with pos weight has correct grad at zero
3373     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3374     const auto target = torch::zeros({3, 1});
3375     const auto pos_weight = torch::ones({3, 1});
3376     BCEWithLogitsLoss(BCEWithLogitsLossOptions()
3377                           .pos_weight(pos_weight)
3378                           .reduction(torch::kSum))(output, target)
3379         .backward();
3380     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3381     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3382     const auto grad = output.grad();
3383     ASSERT_TRUE(torch::allclose(grad, expected_grad));
3384   }
3385   { // test BCE with logits stability
3386     const auto output = torch::tensor({0., -120.});
3387     const auto target = torch::tensor({0., 1.});
3388     const auto pos_weight = torch::tensor({1., 1.});
3389 
3390     const auto out1 = BCEWithLogitsLoss()(output, target);
3391     ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3392 
3393     const auto out2 = BCEWithLogitsLoss(
3394         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3395     ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3396   }
3397 }
3398 
3399 namespace detail {
3400 
3401 namespace F = torch::nn::functional;
3402 
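// Reference batched matmul: multiplies the trailing 2-D matrices of two 4-D
// tensors for every (batch, head) index pair.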
3403 torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) {
3404   TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0));
3405   TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1));
3406   auto retval = torch::zeros(
3407       {a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32);
3408   for (const auto i : c10::irange(a.size(0))) {
3409     for (const auto j : c10::irange(a.size(1))) {
3410       retval[i][j] = torch::matmul(a[i][j], b[i][j]);
3411     }
3412   }
3413   return retval;
3414 }
3415 
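// Reference softmax over the last dimension, with the row max subtracted for
// numerical stability.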
3416 torch::Tensor _softmax(const torch::Tensor& x) {
3417   auto output = torch::zeros(x.sizes());
3418   for (const auto i : c10::irange(x.size(0))) {
3419     for (const auto j : c10::irange(x.size(1))) {
3420       for (const auto k : c10::irange(x.size(2))) {
3421         const auto& x_curr = x[i][j][k];
3422         const auto e_x = torch::exp(x_curr - torch::max(x_curr));
3423         output[i][j][k] = e_x / torch::sum(e_x);
3424       }
3425     }
3426   }
3427   return output;
3428 }
3429 
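// Reference scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V.
// Masked positions are set to -inf before the softmax; attention weights are
// optionally averaged over the head dimension before being returned.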
3430 std::tuple<torch::Tensor, torch::Tensor> _scaled_dot_attn_ref(
3431     const torch::Tensor& Q,
3432     const torch::Tensor& K,
3433     const torch::Tensor& V,
3434     at::IntArrayRef dims,
3435     const torch::Tensor& unseen_mask = {},
3436     const torch::Tensor& key_padding_mask = {},
3437     bool average_attn_weights = true) {
3438   auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3]));
3439   const auto b1 = QKT.size(0);
3440   const auto b2 = QKT.size(1);
3441   const auto s1 = QKT.size(2);
3442   const auto s2 = QKT.size(3);
3443   if (unseen_mask.defined() || key_padding_mask.defined()) {
3444     for (const auto i : c10::irange(b1)) {
3445       for (const auto j : c10::irange(b2)) {
3446         for (const auto m : c10::irange(s1)) {
3447           for (const auto n : c10::irange(s2)) {
3448             if (unseen_mask.defined() &&
3449                 unseen_mask[m][n].item<double>() == 0) {
3450               QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3451             }
3452             if (key_padding_mask.defined() &&
3453                 key_padding_mask[i][n].item<double>() != 0) {
3454               QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3455             }
3456           }
3457         }
3458       }
3459     }
3460   }
3461   auto reference = _softmax(QKT);
3462   auto ref_attn_weight = reference;
3463   if (average_attn_weights) {
3464     // NOLINTNEXTLINE(bugprone-argument-comment)
3465     ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2;
3466   }
3467   reference = _batchmatmul(reference, V);
3468   return std::make_tuple(reference, ref_attn_weight);
3469 }
3470 
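// Reshapes (batch, seq, nheads * d_head) -> (batch, nheads, seq, d_head).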
3471 torch::Tensor _split_heads_ref(
3472     const torch::Tensor& X,
3473     at::IntArrayRef dims,
3474     int nheads,
3475     int d_head) {
3476   auto X_split = X.reshape({dims[0], dims[1], nheads, d_head});
3477   auto X_split_transposed = X_split.permute({0, 2, 1, 3});
3478   return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head});
3479 }
3480 
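// Inverse of _split_heads_ref:
// (batch, nheads, seq, d_head) -> (batch, seq, nheads * d_head).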
3481 torch::Tensor _combine_heads_ref(
3482     const torch::Tensor& X,
3483     at::IntArrayRef dims,
3484     int nheads,
3485     int d_head) {
3486   auto X_transposed = X.permute({0, 2, 1, 3});
3487   auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head});
3488   return reference;
3489 }
3490 
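// Reference linear layer: X @ W^T + b.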
3491 torch::Tensor _fc(
3492     torch::Tensor X,
3493     torch::Tensor X_weight,
3494     torch::Tensor X_bias) {
3495   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3496   auto X_fc_b = X_bias;
3497   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3498   auto X_fc_w = X_weight;
3499   return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b;
3500 }
3501 
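// Runs 100 randomly sized configurations and checks that
// F::multi_head_attention_forward matches the reference helpers above.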
3502 void _multihead_attn_test_helper(
3503     bool add_key_padding_mask = false,
3504     bool add_bias_kv = false,
3505     bool add_zero_attn = false,
3506     bool saved_kv = false,
3507     bool same_embed_dim = false,
3508     bool average_attn_weights = true) {
3509   std::random_device device;
3510   std::mt19937 generator(device());
3511   std::uniform_int_distribution<int> d_2_10(2, 10);
3512   std::uniform_int_distribution<int> d_3_10(3, 10);
3513   bool registration_checked = false;
3514   for (const auto i : c10::irange(100)) {
3515     (void)i; // Suppress unused variable warning
3516     const auto batch_sz = d_2_10(generator);
3517     const auto seq_len = d_2_10(generator);
3518     const auto d_head = d_3_10(generator);
3519     const auto nheads = d_3_10(generator);
3520     const auto d_model = d_head * nheads;
3521     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3522     int kv_dim;
3523     if (same_embed_dim) {
3524       kv_dim = d_model;
3525     } else {
3526       std::uniform_int_distribution<int> d(5, 20);
3527       kv_dim = d(generator);
3528       while (kv_dim == d_model) {
3529         kv_dim = d(generator);
3530       }
3531     }
3532     std::vector<int64_t> dims{batch_sz, seq_len, kv_dim};
3533     torch::Tensor saved_k;
3534     torch::Tensor saved_k_tensor;
3535     torch::Tensor saved_v;
3536     torch::Tensor saved_v_tensor;
3537     if (saved_kv) {
3538       saved_k = torch::rand({batch_sz * nheads, seq_len, d_head});
3539       saved_k_tensor = saved_k;
3540       saved_v = torch::rand({batch_sz * nheads, seq_len, d_head});
3541       saved_v_tensor = saved_v;
3542     }
3543     torch::Tensor key_padding_mask;
3544     torch::Tensor key_padding_mask_tensor;
3545     if (add_key_padding_mask) {
3546       const auto seq_mask = torch::randint(0, 2, {1, seq_len});
3547       key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1;
3548       key_padding_mask_tensor = key_padding_mask;
3549     }
3550     const auto decoder_state = torch::rand({batch_sz, d_model});
3551     const torch::Tensor K = torch::rand(dims);
3552     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3553     const torch::Tensor V = K;
3554     const torch::Tensor Q =
3555         decoder_state.clone().resize_({batch_sz, 1, d_model});
3556     auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat);
3557     const torch::Tensor attn_mask_tensor = attn_mask.clone();
3558     attn_mask_tensor.masked_fill_(
3559         attn_mask_tensor == 0, -std::numeric_limits<double>::infinity());
3560     attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0));
3561 
3562     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3563     const torch::Tensor decoder_state_tensor = decoder_state;
3564     const torch::Tensor source_hid_tensor = K.transpose(0, 1);
3565 
3566     const auto options = MultiheadAttentionOptions(d_model, nheads)
3567                              .add_bias_kv(add_bias_kv)
3568                              .add_zero_attn(add_zero_attn)
3569                              .kdim(kv_dim)
3570                              .vdim(kv_dim);
3571     const auto multihead_attn_module = MultiheadAttention(options);
3572 
3573     if (!registration_checked) {
3574       // make sure parameters are all registered correctly
3575       auto named_parameters = multihead_attn_module->named_parameters();
3576       if (same_embed_dim) {
3577         ASSERT_TRUE(named_parameters.contains("in_proj_weight"));
3578       } else {
3579         ASSERT_TRUE(named_parameters.contains("q_proj_weight"));
3580         ASSERT_TRUE(named_parameters.contains("k_proj_weight"));
3581         ASSERT_TRUE(named_parameters.contains("v_proj_weight"));
3582       }
3583       if (add_bias_kv) {
3584         ASSERT_TRUE(named_parameters.contains("bias_k"));
3585         ASSERT_TRUE(named_parameters.contains("bias_v"));
3586       }
3587       // make sure sub modules are all registered correctly
3588       auto submodules = multihead_attn_module->named_children();
3589       ASSERT_TRUE(submodules.contains("out_proj"));
3590       registration_checked = true;
3591     }
3592 
3593     torch::Tensor bias_k;
3594     torch::Tensor bias_v;
3595     if (add_bias_kv) {
3596       bias_k = multihead_attn_module->bias_k.detach();
3597       bias_v = multihead_attn_module->bias_v.detach();
3598     } else {
3599       bias_k.reset();
3600       bias_v.reset();
3601     }
3602 
3603     torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1);
3604     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3605     torch::Tensor _V = source_hid_tensor;
3606     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3607     torch::Tensor _K = source_hid_tensor;
3608 
3609     torch::Tensor result;
3610     torch::Tensor result_weight;
3611     if (multihead_attn_module->_qkv_same_embed_dim) {
3612       std::tie(result, result_weight) = F::multi_head_attention_forward(
3613           _Q,
3614           _K,
3615           _V,
3616           F::MultiheadAttentionForwardFuncOptions(
3617               /*embed_dim_to_check=*/d_model,
3618               /*num_heads=*/nheads,
3619               /*in_proj_weight=*/multihead_attn_module->in_proj_weight,
3620               /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3621               /*bias_k=*/multihead_attn_module->bias_k,
3622               /*bias_v=*/multihead_attn_module->bias_v,
3623               /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3624               /*dropout_p=*/multihead_attn_module->options.dropout(),
3625               /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3626               /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3627               .training(multihead_attn_module->is_training())
3628               .key_padding_mask(key_padding_mask_tensor)
3629               .need_weights(true)
3630               .attn_mask(attn_mask_tensor)
3631               .static_k(saved_k_tensor)
3632               .static_v(saved_v_tensor)
3633               .average_attn_weights(average_attn_weights));
3634     } else {
3635       std::tie(result, result_weight) = F::multi_head_attention_forward(
3636           _Q,
3637           _K,
3638           _V,
3639           F::MultiheadAttentionForwardFuncOptions(
3640               /*embed_dim_to_check=*/d_model,
3641               /*num_heads=*/nheads,
3642               /*in_proj_weight=*/{},
3643               /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3644               /*bias_k=*/multihead_attn_module->bias_k,
3645               /*bias_v=*/multihead_attn_module->bias_v,
3646               /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3647               /*dropout_p=*/multihead_attn_module->options.dropout(),
3648               /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3649               /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3650               .training(multihead_attn_module->is_training())
3651               .key_padding_mask(key_padding_mask_tensor)
3652               .need_weights(true)
3653               .attn_mask(attn_mask_tensor)
3654               .use_separate_proj_weight(true)
3655               .q_proj_weight(multihead_attn_module->q_proj_weight)
3656               .k_proj_weight(multihead_attn_module->k_proj_weight)
3657               .v_proj_weight(multihead_attn_module->v_proj_weight)
3658               .static_k(saved_k_tensor)
3659               .static_v(saved_v_tensor)
3660               .average_attn_weights(average_attn_weights));
3661     }
3662     result = result.squeeze(0).detach();
3663     torch::Tensor q_proj_weight;
3664     torch::Tensor k_proj_weight;
3665     torch::Tensor v_proj_weight;
3666     if (multihead_attn_module->_qkv_same_embed_dim) {
3667       q_proj_weight =
3668           multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model);
3669       k_proj_weight = multihead_attn_module->in_proj_weight.slice(
3670           /*dim=*/0, d_model, (d_model * 2));
3671       v_proj_weight =
3672           multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2));
3673     } else {
3674       q_proj_weight = multihead_attn_module->q_proj_weight;
3675       k_proj_weight = multihead_attn_module->k_proj_weight;
3676       v_proj_weight = multihead_attn_module->v_proj_weight;
3677     }
3678     auto Q_fc =
3679         _fc(Q,
3680             q_proj_weight,
3681             multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model));
3682     auto K_fc =
3683         _fc(K,
3684             k_proj_weight,
3685             multihead_attn_module->in_proj_bias.slice(
3686                 /*dim=*/0, d_model, (d_model * 2)));
3687     auto V_fc = _fc(
3688         V,
3689         v_proj_weight,
3690         multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2)));
3691 
3692     if (add_bias_kv) {
3693       K_fc = torch::cat(
3694           {K_fc,
3695            bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)},
3696           /*dim=*/1);
3697       V_fc = torch::cat(
3698           {V_fc,
3699            bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)},
3700           /*dim=*/1);
3701       if (attn_mask.defined()) {
3702         attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3703       }
3704       if (key_padding_mask.defined()) {
3705         key_padding_mask = torch::cat(
3706             {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3707             /*dim=*/1);
3708       }
3709       dims[1] += 1;
3710     }
3711     const auto Q_split =
3712         _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head);
3713     torch::Tensor K_split;
3714     if (saved_k.defined()) {
3715       K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head});
3716     } else {
3717       K_split = _split_heads_ref(K_fc, dims, nheads, d_head);
3718     }
3719     torch::Tensor V_split;
3720     if (saved_v.defined()) {
3721       V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head});
3722     } else {
3723       V_split = _split_heads_ref(V_fc, dims, nheads, d_head);
3724     }
3725     if (add_zero_attn) {
3726       dims[1] += 1;
3727       K_split = torch::cat(
3728           {K_split,
3729            torch::zeros(
3730                {K_split.size(0), K_split.size(1), 1, K_split.size(3)})},
3731           /*dim=*/2);
3732       V_split = torch::cat(
3733           {V_split,
3734            torch::zeros(
3735                {V_split.size(0), V_split.size(1), 1, V_split.size(3)})},
3736           /*dim=*/2);
3737       if (attn_mask.defined()) {
3738         attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3739       }
3740       if (key_padding_mask.defined()) {
3741         key_padding_mask = torch::cat(
3742             {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3743             /*dim=*/1);
3744       }
3745     }
3746     auto [attn_heads, ref_attn_weight] = _scaled_dot_attn_ref(
3747         Q_split,
3748         K_split,
3749         V_split,
3750         Q_split.sizes(),
3751         attn_mask,
3752         key_padding_mask,
3753         average_attn_weights);
3754     const auto combined_attn_heads =
3755         _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head);
3756     auto reference =
3757         _fc(combined_attn_heads,
3758             multihead_attn_module->out_proj->weight,
3759             multihead_attn_module->out_proj->bias);
3760     // NOLINTNEXTLINE(bugprone-argument-comment)
3761     reference = torch::squeeze(reference, /*axis=*/1);
3762 
3763     // result = reference
3764     ASSERT_EQ(result.sizes(), std::vector<int64_t>({batch_sz, d_model}));
3765     ASSERT_TRUE(
3766         torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true));
3767 
3768     // result_weight = ref_attn_weight
3769     result_weight = result_weight.detach();
3770     ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes());
3771     ASSERT_TRUE(torch::allclose(
3772         result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true));
3773   }
3774 }
3775 } // namespace detail
3776 
3777 TEST_F(ModulesTest, MultiheadAttention) {
3778   using namespace ::detail;
3779 
3780   for (auto average_attn_weights : {false, true}) {
3781     // test_multihead_attn_add_zero_attn
3782     _multihead_attn_test_helper(
3783         /*add_key_padding_mask=*/false,
3784         /*add_bias_kv=*/false,
3785         /*add_zero_attn=*/true,
3786         /*saved_kv=*/false,
3787         /*same_embed_dim=*/false,
3788         /*average_attn_weights=*/average_attn_weights);
3789 
3790     // test_multihead_attn_add_bias_kv
3791     _multihead_attn_test_helper(
3792         /*add_key_padding_mask=*/false,
3793         /*add_bias_kv=*/true,
3794         /*add_zero_attn=*/false,
3795         /*saved_kv=*/false,
3796         /*same_embed_dim=*/false,
3797         /*average_attn_weights=*/average_attn_weights);
3798 
3799     // test_multihead_attn_no_masking
3800     _multihead_attn_test_helper();
3801 
3802     // test_multihead_attn_key_padding_mask
3803     _multihead_attn_test_helper(
3804         /*add_key_padding_mask=*/true,
3805         /*add_bias_kv=*/false,
3806         /*add_zero_attn=*/false,
3807         /*saved_kv=*/false,
3808         /*same_embed_dim=*/false,
3809         /*average_attn_weights=*/average_attn_weights);
3810 
3811     // test_multihead_attn_saved_kv
3812     _multihead_attn_test_helper(
3813         /*add_key_padding_mask=*/false,
3814         /*add_bias_kv=*/false,
3815         /*add_zero_attn=*/false,
3816         /*saved_kv=*/true,
3817         /*same_embed_dim=*/false,
3818         /*average_attn_weights=*/average_attn_weights);
3819 
3820     // test_multihead_attn_add_bias_kv_zero_attn
3821     _multihead_attn_test_helper(
3822         /*add_key_padding_mask=*/true,
3823         /*add_bias_kv=*/true,
3824         /*add_zero_attn=*/true,
3825         /*saved_kv=*/false,
3826         /*same_embed_dim=*/false,
3827         /*average_attn_weights=*/average_attn_weights);
3828 
3829     // test_multihead_attn_all_arguments1
3830     _multihead_attn_test_helper(
3831         /*add_key_padding_mask=*/true,
3832         /*add_bias_kv=*/false,
3833         /*add_zero_attn=*/true,
3834         /*saved_kv=*/true,
3835         /*same_embed_dim=*/false,
3836         /*average_attn_weights=*/average_attn_weights);
3837 
3838     ASSERT_THROWS_WITH(
3839         // test_multihead_attn_all_arguments2
3840         _multihead_attn_test_helper(
3841             /*add_key_padding_mask=*/true,
3842             /*add_bias_kv=*/true,
3843             /*add_zero_attn=*/true,
3844             /*saved_kv=*/true,
3845             /*same_embed_dim=*/false,
3846             /*average_attn_weights=*/average_attn_weights),
3847         "bias cannot be added to static key");
3848 
3849     // test_multihead_attn_all_arguments3
3850     _multihead_attn_test_helper(
3851         /*add_key_padding_mask=*/true,
3852         /*add_bias_kv=*/false,
3853         /*add_zero_attn=*/true,
3854         /*saved_kv=*/true,
3855         /*same_embed_dim=*/true,
3856         /*average_attn_weights=*/average_attn_weights);
3857   }
3858 }
3859 
3860 TEST_F(ModulesTest, PrettyPrintIdentity) {
3861   ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
3862 }
3863 
3864 TEST_F(ModulesTest, PrettyPrintFlatten) {
3865   ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)");
3866   ASSERT_EQ(
3867       c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
3868       "torch::nn::Flatten(start_dim=2, end_dim=4)");
3869 }
3870 
3871 TEST_F(ModulesTest, PrettyPrintUnflatten) {
3872   ASSERT_EQ(
3873       c10::str(Unflatten(UnflattenOptions(0, {2, 2}))),
3874       "torch::nn::Unflatten(dim=0, unflattened_size={2, 2})");
3875   ASSERT_EQ(
3876       c10::str(Unflatten(UnflattenOptions(
3877           "B",
3878           {std::pair<std::string, int64_t>{"B1", 2},
3879            std::pair<std::string, int64_t>{"B2", 2}}))),
3880       "torch::nn::Unflatten(dim=\"B\", unflattened_size={{\"B1\", 2}, {\"B2\", 2}})");
3881 }
3882 
3883 TEST_F(ModulesTest, ReflectionPad1d) {
3884   {
3885     ReflectionPad1d m(ReflectionPad1dOptions(2));
3886     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3887     auto output = m(input);
3888     auto expected = torch::tensor(
3889         {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}},
3890         torch::kFloat);
3891     ASSERT_TRUE(output.allclose(expected));
3892   }
3893   {
3894     ReflectionPad1d m(ReflectionPad1dOptions({3, 1}));
3895     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3896     auto output = m(input);
3897     auto expected = torch::tensor(
3898         {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}},
3899         torch::kFloat);
3900     ASSERT_TRUE(output.allclose(expected));
3901   }
3902 }
3903 
3904 TEST_F(ModulesTest, ReflectionPad2d) {
3905   {
3906     ReflectionPad2d m(ReflectionPad2dOptions(2));
3907     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3908     auto output = m(input);
3909     auto expected = torch::tensor(
3910         {{{{8., 7., 6., 7., 8., 7., 6.},
3911            {5., 4., 3., 4., 5., 4., 3.},
3912            {2., 1., 0., 1., 2., 1., 0.},
3913            {5., 4., 3., 4., 5., 4., 3.},
3914            {8., 7., 6., 7., 8., 7., 6.},
3915            {5., 4., 3., 4., 5., 4., 3.},
3916            {2., 1., 0., 1., 2., 1., 0.}}}},
3917         torch::kFloat);
3918     ASSERT_TRUE(output.allclose(expected));
3919   }
3920   {
3921     ReflectionPad2d m(ReflectionPad2dOptions({1, 1, 2, 0}));
3922     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3923     auto output = m(input);
3924     auto expected = torch::tensor(
3925         {{{{7., 6., 7., 8., 7.},
3926            {4., 3., 4., 5., 4.},
3927            {1., 0., 1., 2., 1.},
3928            {4., 3., 4., 5., 4.},
3929            {7., 6., 7., 8., 7.}}}},
3930         torch::kFloat);
3931     ASSERT_TRUE(output.allclose(expected));
3932   }
3933 }
3934 
3935 TEST_F(ModulesTest, ReflectionPad3d) {
3936   {
3937     ReflectionPad3d m(ReflectionPad3dOptions(1));
3938     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
3939     auto output = m(input);
3940     auto expected = torch::tensor(
3941         {{{{{7., 6., 7., 6.},
3942             {5., 4., 5., 4.},
3943             {7., 6., 7., 6.},
3944             {5., 4., 5., 4.}},
3945            {{3., 2., 3., 2.},
3946             {1., 0., 1., 0.},
3947             {3., 2., 3., 2.},
3948             {1., 0., 1., 0.}},
3949            {{7., 6., 7., 6.},
3950             {5., 4., 5., 4.},
3951             {7., 6., 7., 6.},
3952             {5., 4., 5., 4.}},
3953            {{3., 2., 3., 2.},
3954             {1., 0., 1., 0.},
3955             {3., 2., 3., 2.},
3956             {1., 0., 1., 0.}}}}},
3957         torch::kFloat);
3958     ASSERT_TRUE(output.allclose(expected));
3959   }
3960   {
3961     ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2}));
3962     auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2});
3963     auto output = m(input);
3964     auto expected = torch::tensor(
3965         {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3966            {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}},
3967            {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3968            {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3969            {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}},
3970            {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3971            {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}},
3972         torch::kFloat);
3973     ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 3, 3}));
3974     ASSERT_TRUE(output.allclose(expected));
3975   }
3976 }
3977 TEST_F(ModulesTest, ReplicationPad1d) {
3978   {
3979     ReplicationPad1d m(ReplicationPad1dOptions(2));
3980     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3981     auto output = m(input);
3982     auto expected = torch::tensor(
3983         {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}},
3984         torch::kFloat);
3985     ASSERT_TRUE(output.allclose(expected));
3986   }
3987   {
3988     ReplicationPad1d m(ReplicationPad1dOptions({3, 1}));
3989     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3990     auto output = m(input);
3991     auto expected = torch::tensor(
3992         {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}},
3993         torch::kFloat);
3994     ASSERT_TRUE(output.allclose(expected));
3995   }
3996 }
3997 
3998 TEST_F(ModulesTest, ReplicationPad2d) {
3999   {
4000     ReplicationPad2d m(ReplicationPad2dOptions(2));
4001     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4002     auto output = m(input);
4003     auto expected = torch::tensor(
4004         {{{{0., 0., 0., 1., 2., 2., 2.},
4005            {0., 0., 0., 1., 2., 2., 2.},
4006            {0., 0., 0., 1., 2., 2., 2.},
4007            {3., 3., 3., 4., 5., 5., 5.},
4008            {6., 6., 6., 7., 8., 8., 8.},
4009            {6., 6., 6., 7., 8., 8., 8.},
4010            {6., 6., 6., 7., 8., 8., 8.}}}},
4011         torch::kFloat);
4012     ASSERT_TRUE(output.allclose(expected));
4013   }
4014   {
4015     ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0}));
4016     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4017     auto output = m(input);
4018     auto expected = torch::tensor(
4019         {{{{0., 0., 1., 2., 2.},
4020            {0., 0., 1., 2., 2.},
4021            {0., 0., 1., 2., 2.},
4022            {3., 3., 4., 5., 5.},
4023            {6., 6., 7., 8., 8.}}}},
4024         torch::kFloat);
4025     ASSERT_TRUE(output.allclose(expected));
4026   }
4027 }
4028 
4029 TEST_F(ModulesTest, ReplicationPad3d) {
4030   {
4031     ReplicationPad3d m(ReplicationPad3dOptions(1));
4032     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4033     auto output = m(input);
4034     auto expected = torch::tensor(
4035         {{{{{0., 0., 1., 1.},
4036             {0., 0., 1., 1.},
4037             {2., 2., 3., 3.},
4038             {2., 2., 3., 3.}},
4039            {{0., 0., 1., 1.},
4040             {0., 0., 1., 1.},
4041             {2., 2., 3., 3.},
4042             {2., 2., 3., 3.}},
4043            {{4., 4., 5., 5.},
4044             {4., 4., 5., 5.},
4045             {6., 6., 7., 7.},
4046             {6., 6., 7., 7.}},
4047            {{4., 4., 5., 5.},
4048             {4., 4., 5., 5.},
4049             {6., 6., 7., 7.},
4050             {6., 6., 7., 7.}}}}},
4051         torch::kFloat);
4052     ASSERT_TRUE(output.allclose(expected));
4053   }
4054   {
4055     ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}));
4056     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4057     auto output = m(input);
4058     auto expected = torch::tensor(
4059         {{{{{0., 0., 1., 1., 1.},
4060             {0., 0., 1., 1., 1.},
4061             {2., 2., 3., 3., 3.},
4062             {2., 2., 3., 3., 3.},
4063             {2., 2., 3., 3., 3.}},
4064            {{0., 0., 1., 1., 1.},
4065             {0., 0., 1., 1., 1.},
4066             {2., 2., 3., 3., 3.},
4067             {2., 2., 3., 3., 3.},
4068             {2., 2., 3., 3., 3.}},
4069            {{4., 4., 5., 5., 5.},
4070             {4., 4., 5., 5., 5.},
4071             {6., 6., 7., 7., 7.},
4072             {6., 6., 7., 7., 7.},
4073             {6., 6., 7., 7., 7.}},
4074            {{4., 4., 5., 5., 5.},
4075             {4., 4., 5., 5., 5.},
4076             {6., 6., 7., 7., 7.},
4077             {6., 6., 7., 7., 7.},
4078             {6., 6., 7., 7., 7.}},
4079            {{4., 4., 5., 5., 5.},
4080             {4., 4., 5., 5., 5.},
4081             {6., 6., 7., 7., 7.},
4082             {6., 6., 7., 7., 7.},
4083             {6., 6., 7., 7., 7.}}}}},
4084         torch::kFloat);
4085     ASSERT_TRUE(output.allclose(expected));
4086   }
4087 }
4088 
4089 TEST_F(ModulesTest, ZeroPad1d) {
4090   {
4091     ZeroPad1d m(ZeroPad1dOptions(2));
4092     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
4093     auto output = m(input);
4094     auto expected = torch::tensor(
4095         {{{0., 0., 0., 1., 2., 3., 0., 0.}, {0., 0., 4., 5., 6., 7., 0., 0.}}},
4096         torch::kFloat);
4097     ASSERT_TRUE(output.allclose(expected));
4098   }
4099   {
4100     ZeroPad1d m(ZeroPad1dOptions({3, 1}));
4101     auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
4102     auto output = m(input);
4103     auto expected = torch::tensor(
4104         {{{0., 0., 0., 0., 1., 2., 0.}, {0., 0., 0., 3., 4., 5., 0.}}},
4105         torch::kFloat);
4106     ASSERT_TRUE(output.allclose(expected));
4107   }
4108 }
4109 
4110 TEST_F(ModulesTest, ZeroPad2d) {
4111   {
4112     ZeroPad2d m(ZeroPad2dOptions(2));
4113     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4114     auto output = m(input);
4115     auto expected = torch::tensor(
4116         {{{{0., 0., 0., 0., 0., 0., 0.},
4117            {0., 0., 0., 0., 0., 0., 0.},
4118            {0., 0., 0., 1., 2., 0., 0.},
4119            {0., 0., 3., 4., 5., 0., 0.},
4120            {0., 0., 6., 7., 8., 0., 0.},
4121            {0., 0., 0., 0., 0., 0., 0.},
4122            {0., 0., 0., 0., 0., 0., 0.}}}},
4123         torch::kFloat);
4124     ASSERT_TRUE(output.allclose(expected));
4125   }
4126   {
4127     ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0}));
4128     auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4129     auto output = m(input);
4130     auto expected = torch::tensor(
4131         {{{{0., 0., 0., 0., 0.},
4132            {0., 0., 0., 0., 0.},
4133            {0., 0., 1., 2., 0.},
4134            {0., 3., 4., 5., 0.},
4135            {0., 6., 7., 8., 0.}}}},
4136         torch::kFloat);
4137     ASSERT_TRUE(output.allclose(expected));
4138   }
4139 }
4140 
4141 TEST_F(ModulesTest, ZeroPad3d) {
4142   {
4143     ZeroPad3d m(ZeroPad3dOptions(1));
4144     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4145     auto output = m(input);
4146     auto expected = torch::tensor(
4147         {{{{{0., 0., 0., 0.},
4148             {0., 0., 0., 0.},
4149             {0., 0., 0., 0.},
4150             {0., 0., 0., 0.}},
4151            {{0., 0., 0., 0.},
4152             {0., 0., 1., 0.},
4153             {0., 2., 3., 0.},
4154             {0., 0., 0., 0.}},
4155            {{0., 0., 0., 0.},
4156             {0., 4., 5., 0.},
4157             {0., 6., 7., 0.},
4158             {0., 0., 0., 0.}},
4159            {{0., 0., 0., 0.},
4160             {0., 0., 0., 0.},
4161             {0., 0., 0., 0.},
4162             {0., 0., 0., 0.}}}}},
4163         torch::kFloat);
4164     ASSERT_TRUE(output.allclose(expected));
4165   }
4166   {
4167     ZeroPad3d m(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}));
4168     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4169     auto output = m(input);
4170     auto expected = torch::tensor(
4171         {{{{{0., 0., 0., 0., 0.},
4172             {0., 0., 0., 0., 0.},
4173             {0., 0., 0., 0., 0.},
4174             {0., 0., 0., 0., 0.},
4175             {0., 0., 0., 0., 0.}},
4176            {{0., 0., 0., 0., 0.},
4177             {0., 0., 1., 0., 0.},
4178             {0., 2., 3., 0., 0.},
4179             {0., 0., 0., 0., 0.},
4180             {0., 0., 0., 0., 0.}},
4181            {{0., 0., 0., 0., 0.},
4182             {0., 4., 5., 0., 0.},
4183             {0., 6., 7., 0., 0.},
4184             {0., 0., 0., 0., 0.},
4185             {0., 0., 0., 0., 0.}},
4186            {{0., 0., 0., 0., 0.},
4187             {0., 0., 0., 0., 0.},
4188             {0., 0., 0., 0., 0.},
4189             {0., 0., 0., 0., 0.},
4190             {0., 0., 0., 0., 0.}},
4191            {{0., 0., 0., 0., 0.},
4192             {0., 0., 0., 0., 0.},
4193             {0., 0., 0., 0., 0.},
4194             {0., 0., 0., 0., 0.},
4195             {0., 0., 0., 0., 0.}}}}},
4196         torch::kFloat);
4197     ASSERT_TRUE(output.allclose(expected));
4198   }
4199 }
4200 
4201 TEST_F(ModulesTest, ConstantPad1d) {
4202   {
4203     ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
4204     auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
4205     auto output = m(input);
4206     auto expected = torch::tensor(
4207         {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000},
4208           {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}},
4209         torch::kFloat);
4210     ASSERT_TRUE(output.allclose(expected));
4211   }
4212   {
4213     ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5));
4214     auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
4215     auto output = m(input);
4216     auto expected = torch::tensor(
4217         {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000},
4218           {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}},
4219         torch::kFloat);
4220     ASSERT_TRUE(output.allclose(expected));
4221   }
4222 }
4223 
4224 TEST_F(ModulesTest, ConstantPad2d) {
4225   {
4226     ConstantPad2d m(ConstantPad2dOptions(2, 3.5));
4227     auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4228     auto output = m(input);
4229     auto expected = torch::tensor(
4230         {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4231           {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4232           {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4233           {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4234           {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4235           {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4236         torch::kFloat);
4237     ASSERT_TRUE(output.allclose(expected));
4238   }
4239   {
4240     ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5));
4241     auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4242     auto output = m(input);
4243     auto expected = torch::tensor(
4244         {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4245           {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4246           {3.5000, 3.5000, 3.5000, 0.0000, 1.0000},
4247           {3.5000, 3.5000, 3.5000, 2.0000, 3.0000},
4248           {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4249         torch::kFloat);
4250     ASSERT_TRUE(output.allclose(expected));
4251   }
4252 }
4253 
4254 TEST_F(ModulesTest, ConstantPad3d) {
4255   {
4256     ConstantPad3d m(ConstantPad3dOptions(1, 3.5));
4257     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4258     auto output = m(input);
4259     auto expected = torch::tensor(
4260         {{{{{3.5000, 3.5000, 3.5000, 3.5000},
4261             {3.5000, 3.5000, 3.5000, 3.5000},
4262             {3.5000, 3.5000, 3.5000, 3.5000},
4263             {3.5000, 3.5000, 3.5000, 3.5000}},
4264            {{3.5000, 3.5000, 3.5000, 3.5000},
4265             {3.5000, 0.0000, 1.0000, 3.5000},
4266             {3.5000, 2.0000, 3.0000, 3.5000},
4267             {3.5000, 3.5000, 3.5000, 3.5000}},
4268            {{3.5000, 3.5000, 3.5000, 3.5000},
4269             {3.5000, 4.0000, 5.0000, 3.5000},
4270             {3.5000, 6.0000, 7.0000, 3.5000},
4271             {3.5000, 3.5000, 3.5000, 3.5000}},
4272            {{3.5000, 3.5000, 3.5000, 3.5000},
4273             {3.5000, 3.5000, 3.5000, 3.5000},
4274             {3.5000, 3.5000, 3.5000, 3.5000},
4275             {3.5000, 3.5000, 3.5000, 3.5000}}}}},
4276         torch::kFloat);
4277     ASSERT_TRUE(output.allclose(expected));
4278   }
4279   {
4280     ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5));
4281     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4282     auto output = m(input);
4283     auto expected = torch::tensor(
4284         {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4285             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4286             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4287             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4288             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4289            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4290             {3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4291             {3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4292             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4293             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4294            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4295             {3.5000, 4.0000, 5.0000, 3.5000, 3.5000},
4296             {3.5000, 6.0000, 7.0000, 3.5000, 3.5000},
4297             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4298             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4299            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4300             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4301             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4302             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4303             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4304            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4305             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4306             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4307             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4308             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}},
4309         torch::kFloat);
4310     ASSERT_TRUE(output.allclose(expected));
4311   }
4312 }
4313 
4314 TEST_F(ModulesTest, CrossMapLRN2d) {
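  // CrossMapLRN2d normalizes across channels:
  //   b_c = a_c / (k + (alpha / size) * sum_{c'} a_{c'}^2)^beta,
  // summing over the `size` nearest channels. With a single channel, e.g.
  // a = 1: 1 / (1 + (1e-4 / 3) * 1)^0.75 ~= 0.99997497, the second entry of
  // `expected` below.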
4315   // size 3, default options
4316   auto input =
4317       torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true);
4318   auto expected = torch::tensor(
4319       {{{{0.00000000, 0.99997497, 1.99980010},
4320          {2.99932500, 3.99840070, 4.99687700},
4321          {5.99460600, 6.99143740, 7.98722360}}}},
4322       torch::kFloat32);
4323   auto grad_expected = torch::tensor(
4324       {{{{1.00000000, 0.99992496, 0.99970007},
4325          {0.99932520, 0.99880093, 0.99812720},
4326          {0.99730474, 0.99633380, 0.99521490}}}},
4327       torch::kFloat32);
4328   auto crossmaplrn2d = CrossMapLRN2d(3);
4329   auto output = crossmaplrn2d(input);
4330   output.sum().backward();
4331 
4332   ASSERT_TRUE(input.grad().allclose(grad_expected));
4333   ASSERT_TRUE(output.allclose(expected));
4334 
4335   // size change
4336   crossmaplrn2d =
4337       CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1));
4338   output = crossmaplrn2d(input);
4339   expected = torch::tensor(
4340       {{{{0.00000000, 0.99998120, 1.99985000},
4341          {2.99949400, 3.99880050, 4.99765800},
4342          {5.99595300, 6.99357600, 7.99041300}}}},
4343       torch::kFloat32);
4344   ASSERT_TRUE(output.allclose(expected));
4345 
4346   // alpha change
4347   crossmaplrn2d =
4348       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1));
4349   output = crossmaplrn2d(input);
4350   expected = torch::tensor(
4351       {{{{0.00000000, 0.99975010, 1.99800230},
4352          {2.99326750, 3.98407440, 4.96897600},
4353          {5.94656100, 6.91545720, 7.87434340}}}},
4354       torch::kFloat32);
4355   ASSERT_TRUE(output.allclose(expected));
4356 
4357   // beta change
4358   crossmaplrn2d =
4359       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1));
4360   output = crossmaplrn2d(input);
4361   expected = torch::tensor(
4362       {{{{0.00000000, 0.99996830, 1.99974680},
4363          {2.99914500, 3.99797440, 4.99604460},
4364          {5.99316840, 6.98915600, 7.98382000}}}},
4365       torch::kFloat32);
4366   ASSERT_TRUE(output.allclose(expected));
4367 
4368   // k change
4369   crossmaplrn2d =
4370       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2));
4371   output = crossmaplrn2d(input);
4372   expected = torch::tensor(
4373       {{{{0.00000000, 0.59459610, 1.18914770},
4374          {1.78361000, 2.37793870, 2.97208900},
4375          {3.56601700, 4.15967700, 4.75302650}}}},
4376       torch::kFloat32);
4377   ASSERT_TRUE(output.allclose(expected));
4378 }
4379 
4380 TEST_F(ModulesTest, RNNCell) {
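  // RNNCell computes h' = tanh(W_ih x + b_ih + W_hh h + b_hh) with the
  // default kTanh nonlinearity; omitting hx uses a zero initial hidden
  // state. The trailing allclose arguments are rtol=1e-05 and atol=2e-04,
  // loosened to match the hard-coded expected values.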
4381   torch::manual_seed(0);
4382   auto rnn = RNNCell(1, 2);
4383 
4384   auto input = torch::randn({3, 1});
4385   auto hx = torch::randn({3, 2});
4386   auto output = rnn(input, hx);
4387   auto expected =
4388       torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}});
4389   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4390 
4391   output = rnn(input);
4392   expected =
4393       torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}});
4394   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4395 
4396   input = torch::randn({1});
4397   hx = torch::randn({2});
4398   output = rnn(input, hx);
4399   expected = torch::tensor({0.2808, 0.6505});
4400   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4401 
4402   {
4403     auto input = torch::randn({3, 2});
4404     auto hx = torch::randn({3, 2});
4405     ASSERT_THROWS_WITH(
4406         rnn(input, hx), "input has inconsistent input_size: got 2 expected 1");
4407   }
4408 
4409   {
4410     auto input = torch::randn({3, 1});
4411     auto hx = torch::randn({3, 1});
4412     ASSERT_THROWS_WITH(
4413         rnn(input, hx),
4414         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4415   }
4416 
4417   {
4418     auto input = torch::randn({3, 1, 1, 1, 1});
4419     auto hx = torch::randn({3, 2});
4420     ASSERT_THROWS_WITH(
4421         rnn(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4422   }
4423 
4424   {
4425     auto input = torch::randn({3, 1});
4426     auto hx = torch::randn({3, 1, 1, 1, 2});
4427     ASSERT_THROWS_WITH(
4428         rnn(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4429   }
4430 }
4431 
4432 TEST_F(ModulesTest, LSTMCell) {
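  // LSTMCell applies the standard gates
  //   i = sigmoid(.), f = sigmoid(.), g = tanh(.), o = sigmoid(.),
  //   c' = f * c + i * g,  h' = o * tanh(c'),
  // and returns the (h', c') pair as a tuple; omitting the state tuple uses
  // zeros for both hx and cx.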
4433   torch::manual_seed(0);
4434   auto lstm = LSTMCell(1, 2);
4435 
4436   auto input = torch::randn({3, 1});
4437   auto hx = torch::randn({3, 2});
4438   auto cx = torch::randn({3, 2});
4439   auto output = lstm(input, std::make_tuple(hx, cx));
4440   auto output_hx = std::get<0>(output);
4441   auto output_cx = std::get<1>(output);
4442   auto expected_hx =
4443       torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}});
4444   auto expected_cx =
4445       torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}});
4446   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4447   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4448 
4449   output = lstm(input);
4450   output_hx = std::get<0>(output);
4451   output_cx = std::get<1>(output);
4452   expected_hx =
4453       torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}});
4454   expected_cx =
4455       torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}});
4456   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4457   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4458 
4459   input = torch::randn({1});
4460   hx = torch::randn({2});
4461   cx = torch::randn({2});
4462   output = lstm(input, std::make_tuple(hx, cx));
4463   output_hx = std::get<0>(output);
4464   output_cx = std::get<1>(output);
4465   expected_hx = torch::tensor({-0.0443, 0.1537});
4466   expected_cx = torch::tensor({-0.1195, 0.2144});
4467   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4468   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4469 
4470   {
4471     auto input = torch::randn({3, 2});
4472     auto hx = torch::randn({3, 2});
4473     auto cx = torch::randn({3, 2});
4474     ASSERT_THROWS_WITH(
4475         lstm(input, std::make_tuple(hx, cx)),
4476         "input has inconsistent input_size: got 2 expected 1");
4477   }
4478 
4479   {
4480     auto input = torch::randn({3, 1});
4481     auto hx = torch::randn({3, 1});
4482     auto cx = torch::randn({3, 2});
4483     ASSERT_THROWS_WITH(
4484         lstm(input, std::make_tuple(hx, cx)),
4485         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4486   }
4487 
4488   {
4489     auto input = torch::randn({3, 1});
4490     auto hx = torch::randn({3, 2});
4491     auto cx = torch::randn({3, 1});
4492     ASSERT_THROWS_WITH(
4493         lstm(input, std::make_tuple(hx, cx)),
4494         "hidden1 has inconsistent hidden_size: got 1, expected 2");
4495   }
4496 
4497   {
4498     auto input = torch::randn({3, 1, 1, 1, 1});
4499     auto hx = torch::randn({3, 1});
4500     auto cx = torch::randn({3, 1});
4501     ASSERT_THROWS_WITH(
4502         lstm(input, std::make_tuple(hx, cx)),
4503         "Expected input to be 1D or 2D, got 5D instead");
4504   }
4505 
4506   {
4507     auto input = torch::randn({3, 1});
4508     auto hx = torch::randn({3, 1, 1, 1, 2});
4509     auto cx = torch::randn({3, 2});
4510     ASSERT_THROWS_WITH(
4511         lstm(input, std::make_tuple(hx, cx)),
4512         "Expected hx[0] to be 1D or 2D, got 5D instead");
4513   }
4514 
4515   {
4516     auto input = torch::randn({3, 1});
4517     auto hx = torch::randn({3, 2});
4518     auto cx = torch::randn({3, 1, 1, 1, 2});
4519     ASSERT_THROWS_WITH(
4520         lstm(input, std::make_tuple(hx, cx)),
4521         "Expected hx[1] to be 1D or 2D, got 5D instead");
4522   }
4523 }
4524 
4525 TEST_F(ModulesTest, GRUCell) {
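  // GRUCell uses reset/update gates:
  //   r = sigmoid(.), z = sigmoid(.),
  //   n = tanh(W_in x + b_in + r * (W_hn h + b_hn)),
  //   h' = (1 - z) * n + z * h,
  // again defaulting to a zero hidden state when hx is omitted.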
4526   torch::manual_seed(0);
4527   auto gru = GRUCell(1, 2);
4528 
4529   auto input = torch::randn({3, 1});
4530   auto hx = torch::randn({3, 2});
4531   auto output = gru(input, hx);
4532   auto expected =
4533       torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}});
4534   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4535 
4536   output = gru(input);
4537   expected =
4538       torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}});
4539   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4540 
4541   input = torch::randn({1});
4542   hx = torch::randn({2});
4543   output = gru(input, hx);
4544   expected = torch::tensor({-1.0058, -0.3025});
4545   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4546 
4547   {
4548     auto input = torch::randn({3, 2});
4549     auto hx = torch::randn({3, 2});
4550     ASSERT_THROWS_WITH(
4551         gru(input, hx), "input has inconsistent input_size: got 2 expected 1");
4552   }
4553 
4554   {
4555     auto input = torch::randn({3, 1});
4556     auto hx = torch::randn({3, 1});
4557     ASSERT_THROWS_WITH(
4558         gru(input, hx),
4559         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4560   }
4561 
4562   {
4563     auto input = torch::randn({3, 1, 1, 1, 1});
4564     auto hx = torch::randn({3, 2});
4565     ASSERT_THROWS_WITH(
4566         gru(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4567   }
4568 
4569   {
4570     auto input = torch::randn({3, 1});
4571     auto hx = torch::randn({3, 1, 1, 1, 2});
4572     ASSERT_THROWS_WITH(
4573         gru(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4574   }
4575 }
4576 
4577 TEST_F(ModulesTest, PrettyPrintLinear) {
4578   ASSERT_EQ(
4579       c10::str(Linear(3, 4)),
4580       "torch::nn::Linear(in_features=3, out_features=4, bias=true)");
4581 }
4582 
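// The PrettyPrint tests in this file exercise torch::nn::Module's stream
// output: c10::str(module) goes through operator<<, which calls the module's
// virtual pretty_print(std::ostream&) const override. A minimal sketch of a
// custom override (MyScaleImpl and its `scale` field are illustrative, not
// part of this test suite):
//
//   struct MyScaleImpl : torch::nn::Module {
//     explicit MyScaleImpl(double scale_) : scale(scale_) {}
//     void pretty_print(std::ostream& stream) const override {
//       stream << "MyScale(scale=" << scale << ")";
//     }
//     double scale;
//   };
//   TORCH_MODULE(MyScale);
//
//   // c10::str(MyScale(2.5)) would then yield "MyScale(scale=2.5)".
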
4583 TEST_F(ModulesTest, PrettyPrintBilinear) {
4584   ASSERT_EQ(
4585       c10::str(Bilinear(3, 2, 4)),
4586       "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)");
4587   ASSERT_EQ(
4588       c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))),
4589       "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)");
4590 }
4591 
4592 TEST_F(ModulesTest, PrettyPrintConv) {
4593   ASSERT_EQ(
4594       c10::str(Conv1d(3, 4, 5)),
4595       "torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
4596 
4597   ASSERT_EQ(
4598       c10::str(Conv2d(3, 4, 5)),
4599       "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4600   ASSERT_EQ(
4601       c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
4602       "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4603   {
4604     const auto options =
4605         Conv2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4606     ASSERT_EQ(
4607         c10::str(Conv2d(options)),
4608         "torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4609   }
4610 
4611   ASSERT_EQ(
4612       c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4613       "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4614   {
4615     const auto options = Conv3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4616                              .stride({1, 2, 3})
4617                              .padding(1)
4618                              .dilation(0)
4619                              .groups(2)
4620                              .bias(false)
4621                              .padding_mode(torch::kCircular);
4622     ASSERT_EQ(
4623         c10::str(Conv3d(options)),
4624         "torch::nn::Conv3d("
4625         "4, "
4626         "4, "
4627         "kernel_size=[5, 6, 7], "
4628         "stride=[1, 2, 3], "
4629         "padding=[1, 1, 1], "
4630         "dilation=[0, 0, 0], "
4631         "groups=2, "
4632         "bias=false, "
4633         "padding_mode=kCircular)");
4634   }
4635 }
4636 
4637 TEST_F(ModulesTest, PrettyPrintConvTranspose) {
4638   ASSERT_EQ(
4639       c10::str(ConvTranspose1d(3, 4, 5)),
4640       "torch::nn::ConvTranspose1d(3, 4, kernel_size=5, stride=1)");
4641 
4642   ASSERT_EQ(
4643       c10::str(ConvTranspose2d(3, 4, 5)),
4644       "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4645   ASSERT_EQ(
4646       c10::str(ConvTranspose2d(ConvTranspose2dOptions(3, 4, 5).stride(2))),
4647       "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4648   {
4649     const auto options =
4650         ConvTranspose2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4651     ASSERT_EQ(
4652         c10::str(ConvTranspose2d(options)),
4653         "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4654   }
4655 
4656   ASSERT_EQ(
4657       c10::str(ConvTranspose3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4658       "torch::nn::ConvTranspose3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4659   {
4660     const auto options =
4661         ConvTranspose3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4662             .stride({1, 2, 3})
4663             .padding(1)
4664             .dilation(0)
4665             .groups(2)
4666             .bias(false)
4667             .padding_mode(torch::kCircular);
4668     ASSERT_EQ(
4669         c10::str(ConvTranspose3d(options)),
4670         "torch::nn::ConvTranspose3d("
4671         "4, "
4672         "4, "
4673         "kernel_size=[5, 6, 7], "
4674         "stride=[1, 2, 3], "
4675         "padding=[1, 1, 1], "
4676         "dilation=[0, 0, 0], "
4677         "groups=2, "
4678         "bias=false, "
4679         "padding_mode=kCircular)");
4680   }
4681 }
4682 
4683 TEST_F(ModulesTest, PrettyPrintUpsample) {
4684   ASSERT_EQ(
4685       c10::str(
4686           Upsample(UpsampleOptions().size(std::vector<int64_t>({2, 4, 4})))),
4687       "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)");
4688   ASSERT_EQ(
4689       c10::str(Upsample(UpsampleOptions()
4690                             .scale_factor(std::vector<double>({0.5, 1.5}))
4691                             .mode(torch::kBilinear))),
4692       "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)");
4693 }
4694 
4695 TEST_F(ModulesTest, PrettyPrintFold) {
4696   ASSERT_EQ(
4697       c10::str(Fold(FoldOptions({2, 2}, {5, 5}))),
4698       "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4699   ASSERT_EQ(
4700       c10::str(Fold(
4701           FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))),
4702       "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4703 }
4704 
4705 TEST_F(ModulesTest, PrettyPrintUnfold) {
4706   ASSERT_EQ(
4707       c10::str(Unfold(torch::IntArrayRef({2, 4}))),
4708       "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4709   ASSERT_EQ(
4710       c10::str(
4711           Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))),
4712       "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4713 }
4714 
4715 TEST_F(ModulesTest, PrettyPrintMaxPool) {
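  // As with the other pooling modules, stride defaults to kernel_size when
  // not set explicitly, so MaxPool1d(5) prints stride=5.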
4716   ASSERT_EQ(
4717       c10::str(MaxPool1d(5)),
4718       "torch::nn::MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=false)");
4719   ASSERT_EQ(
4720       c10::str(MaxPool2d(5)),
4721       "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4722   ASSERT_EQ(
4723       c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))),
4724       "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4725   ASSERT_EQ(
4726       c10::str(MaxPool3d(5)),
4727       "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4728   ASSERT_EQ(
4729       c10::str(MaxPool3d(MaxPool3dOptions(5).stride(2))),
4730       "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4731 
4732   const auto options =
4733       MaxPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4734   ASSERT_EQ(
4735       c10::str(MaxPool2d(options)),
4736       "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4737 }
4738 
4739 TEST_F(ModulesTest, PrettyPrintAvgPool) {
4740   ASSERT_EQ(
4741       c10::str(AvgPool1d(5)),
4742       "torch::nn::AvgPool1d(kernel_size=5, stride=5, padding=0)");
4743   ASSERT_EQ(
4744       c10::str(AvgPool2d(5)),
4745       "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4746   ASSERT_EQ(
4747       c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))),
4748       "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0])");
4749   ASSERT_EQ(
4750       c10::str(AvgPool3d(5)),
4751       "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0])");
4752   ASSERT_EQ(
4753       c10::str(AvgPool3d(AvgPool3dOptions(5).stride(2))),
4754       "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0])");
4755 
4756   const auto options =
4757       AvgPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4758   ASSERT_EQ(
4759       c10::str(AvgPool2d(options)),
4760       "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0])");
4761 }
4762 
4763 TEST_F(ModulesTest, PrettyPrintFractionalMaxPool) {
4764   ASSERT_EQ(
4765       c10::str(
4766           FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))),
4767       "torch::nn::FractionalMaxPool2d()");
4768   ASSERT_EQ(
4769       c10::str(
4770           FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))),
4771       "torch::nn::FractionalMaxPool3d()");
4772 }
4773 
4774 TEST_F(ModulesTest, PrettyPrintLPPool) {
4775   ASSERT_EQ(
4776       c10::str(LPPool1d(2, 5)),
4777       "torch::nn::LPPool1d(norm_type=2, kernel_size=5, stride=5, ceil_mode=false)");
4778   ASSERT_EQ(
4779       c10::str(LPPool1d(LPPool1dOptions(1, 2).stride(5).ceil_mode(true))),
4780       "torch::nn::LPPool1d(norm_type=1, kernel_size=2, stride=5, ceil_mode=true)");
4781   ASSERT_EQ(
4782       c10::str(LPPool2d(2, std::vector<int64_t>({1, 2}))),
4783       "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)");
4784   ASSERT_EQ(
4785       c10::str(LPPool2d(LPPool2dOptions(1, std::vector<int64_t>({3, 4}))
4786                             .stride({5, 6})
4787                             .ceil_mode(true))),
4788       "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
4789   ASSERT_EQ(
4790       c10::str(LPPool3d(2, std::vector<int64_t>({1, 2, 3}))),
4791       "torch::nn::LPPool3d(norm_type=2, kernel_size=[1, 2, 3], stride=[1, 2, 3], ceil_mode=false)");
4792   ASSERT_EQ(
4793       c10::str(LPPool3d(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5}))
4794                             .stride({5, 6, 7})
4795                             .ceil_mode(true))),
4796       "torch::nn::LPPool3d(norm_type=1, kernel_size=[3, 4, 5], stride=[5, 6, 7], ceil_mode=true)");
4797 }
4798 
4799 TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
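  // Adaptive pooling accepts std::nullopt for individual dimensions
  // (printed as "None"), meaning that dimension keeps the input's size.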
4800   ASSERT_EQ(
4801       c10::str(AdaptiveMaxPool1d(5)),
4802       "torch::nn::AdaptiveMaxPool1d(output_size=5)");
4803 
4804   const auto options = AdaptiveMaxPool1dOptions(3);
4805   ASSERT_EQ(
4806       c10::str(AdaptiveMaxPool1d(options)),
4807       "torch::nn::AdaptiveMaxPool1d(output_size=3)");
4808 
4809   ASSERT_EQ(
4810       c10::str(AdaptiveMaxPool2d(5)),
4811       "torch::nn::AdaptiveMaxPool2d(output_size=[5, 5])");
4812   ASSERT_EQ(
4813       c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, 6}))),
4814       "torch::nn::AdaptiveMaxPool2d(output_size=[5, 6])");
4815   ASSERT_EQ(
4816       c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, std::nullopt}))),
4817       "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])");
4818   ASSERT_EQ(
4819       c10::str(AdaptiveMaxPool2d(
4820           AdaptiveMaxPool2dOptions({std::nullopt, std::nullopt}))),
4821       "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])");
4822 
4823   ASSERT_EQ(
4824       c10::str(AdaptiveMaxPool3d(5)),
4825       "torch::nn::AdaptiveMaxPool3d(output_size=[5, 5, 5])");
4826   ASSERT_EQ(
4827       c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))),
4828       "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])");
4829   ASSERT_EQ(
4830       c10::str(
4831           AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, std::nullopt, 7}))),
4832       "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])");
4833   ASSERT_EQ(
4834       c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions(
4835           {std::nullopt, std::nullopt, std::nullopt}))),
4836       "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])");
4837 }
4838 
4839 TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
4840   ASSERT_EQ(
4841       c10::str(AdaptiveAvgPool1d(5)),
4842       "torch::nn::AdaptiveAvgPool1d(output_size=5)");
4843 
4844   ASSERT_EQ(
4845       c10::str(AdaptiveAvgPool2d(5)),
4846       "torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])");
4847   ASSERT_EQ(
4848       c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, 6}))),
4849       "torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])");
4850   ASSERT_EQ(
4851       c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, std::nullopt}))),
4852       "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])");
4853   ASSERT_EQ(
4854       c10::str(AdaptiveAvgPool2d(
4855           AdaptiveAvgPool2dOptions({std::nullopt, std::nullopt}))),
4856       "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])");
4857 
4858   ASSERT_EQ(
4859       c10::str(AdaptiveAvgPool3d(5)),
4860       "torch::nn::AdaptiveAvgPool3d(output_size=[5, 5, 5])");
4861   ASSERT_EQ(
4862       c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))),
4863       "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])");
4864   ASSERT_EQ(
4865       c10::str(
4866           AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, std::nullopt, 7}))),
4867       "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])");
4868   ASSERT_EQ(
4869       c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions(
4870           {std::nullopt, std::nullopt, std::nullopt}))),
4871       "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])");
4872 }
4873 
4874 TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
4875   ASSERT_EQ(
4876       c10::str(MaxUnpool1d(5)),
4877       "torch::nn::MaxUnpool1d(kernel_size=5, stride=5, padding=0)");
4878   ASSERT_EQ(
4879       c10::str(MaxUnpool1d(MaxUnpool1dOptions(5).stride(3).padding(1))),
4880       "torch::nn::MaxUnpool1d(kernel_size=5, stride=3, padding=1)");
4881 
4882   ASSERT_EQ(
4883       c10::str(MaxUnpool2d(5)),
4884       "torch::nn::MaxUnpool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4885   ASSERT_EQ(
4886       c10::str(MaxUnpool2d(std::vector<int64_t>{5, 6})),
4887       "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])");
4888   ASSERT_EQ(
4889       c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector<int64_t>{5, 6})
4890                                .stride({3, 4})
4891                                .padding({1, 2}))),
4892       "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])");
4893 }
4894 
4895 TEST_F(ModulesTest, PrettyPrintDropout) {
4896   ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)");
4897   ASSERT_EQ(
4898       c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)");
4899   ASSERT_EQ(
4900       c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))),
4901       "torch::nn::Dropout(p=0.42, inplace=true)");
4902 }
4903 
4904 TEST_F(ModulesTest, PrettyPrintDropout2d) {
4905   ASSERT_EQ(
4906       c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)");
4907   ASSERT_EQ(
4908       c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)");
4909   ASSERT_EQ(
4910       c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))),
4911       "torch::nn::Dropout2d(p=0.42, inplace=true)");
4912 }
4913 
4914 TEST_F(ModulesTest, PrettyPrintDropout3d) {
4915   ASSERT_EQ(
4916       c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)");
4917   ASSERT_EQ(
4918       c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)");
4919   ASSERT_EQ(
4920       c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))),
4921       "torch::nn::Dropout3d(p=0.42, inplace=true)");
4922 }
4923 
4924 TEST_F(ModulesTest, PrettyPrintFunctional) {
4925   ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
4926 }
4927 
4928 TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
4929   ASSERT_EQ(
4930       c10::str(BatchNorm1d(BatchNorm1dOptions(4)
4931                                .eps(0.5)
4932                                .momentum(0.1)
4933                                .affine(false)
4934                                .track_running_stats(true))),
4935       "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4936 }
4937 
4938 TEST_F(ModulesTest, PrettyPrintBatchNorm2d) {
4939   ASSERT_EQ(
4940       c10::str(BatchNorm2d(BatchNorm2dOptions(4)
4941                                .eps(0.5)
4942                                .momentum(0.1)
4943                                .affine(false)
4944                                .track_running_stats(true))),
4945       "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4946 }
4947 
4948 TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
4949   ASSERT_EQ(
4950       c10::str(BatchNorm3d(BatchNorm3dOptions(4)
4951                                .eps(0.5)
4952                                .momentum(0.1)
4953                                .affine(false)
4954                                .track_running_stats(true))),
4955       "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4956 }
4957 
4958 TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
4959   ASSERT_EQ(
4960       c10::str(InstanceNorm1d(InstanceNorm1dOptions(4)
4961                                   .eps(0.5)
4962                                   .momentum(0.1)
4963                                   .affine(false)
4964                                   .track_running_stats(true))),
4965       "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4966 }
4967 
4968 TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
4969   ASSERT_EQ(
4970       c10::str(InstanceNorm2d(InstanceNorm2dOptions(4)
4971                                   .eps(0.5)
4972                                   .momentum(0.1)
4973                                   .affine(false)
4974                                   .track_running_stats(true))),
4975       "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4976 }
4977 
4978 TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
4979   ASSERT_EQ(
4980       c10::str(InstanceNorm3d(InstanceNorm3dOptions(4)
4981                                   .eps(0.5)
4982                                   .momentum(0.1)
4983                                   .affine(false)
4984                                   .track_running_stats(true))),
4985       "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4986 }
4987 
4988 TEST_F(ModulesTest, PrettyPrintLayerNorm) {
4989   ASSERT_EQ(
4990       c10::str(LayerNorm(LayerNormOptions({2, 2}))),
4991       "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)");
4992   ASSERT_EQ(
4993       c10::str(LayerNorm(
4994           LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))),
4995       "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)");
4996 }
4997 
4998 TEST_F(ModulesTest, PrettyPrintGroupNorm) {
4999   ASSERT_EQ(
5000       c10::str(GroupNorm(GroupNormOptions(2, 2))),
5001       "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)");
5002   ASSERT_EQ(
5003       c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))),
5004       "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)");
5005 }
5006 
5007 TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
5008   ASSERT_EQ(
5009       c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
5010       "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)");
5011   ASSERT_EQ(
5012       c10::str(LocalResponseNorm(
5013           LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
5014       "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)");
5015 }
5016 
5017 TEST_F(ModulesTest, PrettyPrintEmbedding) {
5018   ASSERT_EQ(
5019       c10::str(Embedding(EmbeddingOptions(10, 2))),
5020       "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");
5021   ASSERT_EQ(
5022       c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))),
5023       "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");
5024   ASSERT_EQ(
5025       c10::str(Embedding(EmbeddingOptions(10, 2)
5026                              .padding_idx(3)
5027                              .max_norm(2)
5028                              .norm_type(2.5)
5029                              .scale_grad_by_freq(true)
5030                              .sparse(true))),
5031       "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
5032 }
5033 
5034 TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
5035   ASSERT_EQ(
5036       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))),
5037       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)");
5038   ASSERT_EQ(
5039       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))),
5040       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)");
5041   ASSERT_EQ(
5042       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5043                                 .max_norm(2)
5044                                 .norm_type(2.5)
5045                                 .scale_grad_by_freq(true)
5046                                 .sparse(true))),
5047       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
5048   ASSERT_EQ(
5049       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5050                                 .max_norm(2)
5051                                 .norm_type(2.5)
5052                                 .scale_grad_by_freq(true)
5053                                 .sparse(true)
5054                                 .mode(torch::kSum))),
5055       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)");
5056   ASSERT_EQ(
5057       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5058                                 .max_norm(2)
5059                                 .norm_type(2.5)
5060                                 .scale_grad_by_freq(true)
5061                                 .sparse(true)
5062                                 .mode(torch::kSum)
5063                                 .padding_idx(5))),
5064       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)");
5065 }
5066 
5067 TEST_F(ModulesTest, PrettyPrintL1Loss) {
5068   ASSERT_EQ(c10::str(L1Loss()), "torch::nn::L1Loss()");
5069 }
5070 TEST_F(ModulesTest, PrettyPrintKLDivLoss) {
5071   ASSERT_EQ(c10::str(KLDivLoss()), "torch::nn::KLDivLoss()");
5072 }
5073 TEST_F(ModulesTest, PrettyPrintMSELoss) {
5074   ASSERT_EQ(c10::str(MSELoss()), "torch::nn::MSELoss()");
5075 }
5076 TEST_F(ModulesTest, PrettyPrintBCELoss) {
5077   ASSERT_EQ(c10::str(BCELoss()), "torch::nn::BCELoss()");
5078 }
5079 TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) {
5080   ASSERT_EQ(
5081       c10::str(HingeEmbeddingLoss(HingeEmbeddingLossOptions().margin(4))),
5082       "torch::nn::HingeEmbeddingLoss(margin=4)");
5083 }
5084 
5085 TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) {
5086   ASSERT_EQ(
5087       c10::str(CosineEmbeddingLoss(CosineEmbeddingLossOptions().margin(0.25))),
5088       "torch::nn::CosineEmbeddingLoss(margin=0.25)");
5089 }
5090 
5091 TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
5092   ASSERT_EQ(
5093       c10::str(TripletMarginLoss(
5094           TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))),
5095       "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
5096 }
5097 
5098 TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
5099   auto distanceOptions = TripletMarginWithDistanceLossOptions()
5100                              .distance_function([&](const torch::Tensor& x,
5101                                                     const torch::Tensor& y) {
5102                                return torch::pairwise_distance(x, y, 2.0, 1e-6);
5103                              })
5104                              .margin(1.5)
5105                              .swap(true)
5106                              .reduction(torch::kMean);
5107   ASSERT_EQ(
5108       c10::str(TripletMarginWithDistanceLoss(distanceOptions)),
5109       "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
5110 }
5111 
5112 TEST_F(ModulesTest, PrettyPrintNLLLoss) {
5113   ASSERT_EQ(c10::str(NLLLoss()), "torch::nn::NLLLoss()");
5114 }
5115 
5116 TEST_F(ModulesTest, PrettyPrintCrossEntropyLoss) {
5117   ASSERT_EQ(c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()");
5118 }
5119 
5120 TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) {
5121   ASSERT_EQ(
5122       c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()");
5123 }
5124 
5125 TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
5126   ASSERT_EQ(
5127       c10::str(MultiLabelSoftMarginLoss()),
5128       "torch::nn::MultiLabelSoftMarginLoss()");
5129 }
5130 
5131 TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) {
5132   ASSERT_EQ(c10::str(SoftMarginLoss()), "torch::nn::SoftMarginLoss()");
5133 }
5134 
5135 TEST_F(ModulesTest, PrettyPrintCosineSimilarity) {
5136   ASSERT_EQ(
5137       c10::str(CosineSimilarity()),
5138       "torch::nn::CosineSimilarity(dim=1, eps=1e-08)");
5139   ASSERT_EQ(
5140       c10::str(CosineSimilarity(CosineSimilarityOptions().dim(0).eps(0.5))),
5141       "torch::nn::CosineSimilarity(dim=0, eps=0.5)");
5142 }
5143 
5144 TEST_F(ModulesTest, PrettyPrintPairwiseDistance) {
5145   ASSERT_EQ(
5146       c10::str(PairwiseDistance()),
5147       "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)");
5148   ASSERT_EQ(
5149       c10::str(PairwiseDistance(
5150           PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))),
5151       "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)");
5152 }
5153 
5154 TEST_F(ModulesTest, PrettyPrintReflectionPad) {
5155   ASSERT_EQ(
5156       c10::str(ReflectionPad1d(ReflectionPad1dOptions(2))),
5157       "torch::nn::ReflectionPad1d(padding=[2, 2])");
5158   ASSERT_EQ(
5159       c10::str(ReflectionPad1d(ReflectionPad1dOptions({3, 1}))),
5160       "torch::nn::ReflectionPad1d(padding=[3, 1])");
5161   ASSERT_EQ(
5162       c10::str(ReflectionPad2d(ReflectionPad2dOptions(2))),
5163       "torch::nn::ReflectionPad2d(padding=[2, 2, 2, 2])");
5164   ASSERT_EQ(
5165       c10::str(ReflectionPad2d(ReflectionPad2dOptions({1, 1, 2, 0}))),
5166       "torch::nn::ReflectionPad2d(padding=[1, 1, 2, 0])");
5167 }
5168 
5169 TEST_F(ModulesTest, PrettyPrintReplicationPad) {
5170   ASSERT_EQ(
5171       c10::str(ReplicationPad1d(ReplicationPad1dOptions(2))),
5172       "torch::nn::ReplicationPad1d(padding=[2, 2])");
5173   ASSERT_EQ(
5174       c10::str(ReplicationPad1d(ReplicationPad1dOptions({3, 1}))),
5175       "torch::nn::ReplicationPad1d(padding=[3, 1])");
5176   ASSERT_EQ(
5177       c10::str(ReplicationPad2d(ReplicationPad2dOptions(2))),
5178       "torch::nn::ReplicationPad2d(padding=[2, 2, 2, 2])");
5179   ASSERT_EQ(
5180       c10::str(ReplicationPad2d(ReplicationPad2dOptions({1, 1, 2, 0}))),
5181       "torch::nn::ReplicationPad2d(padding=[1, 1, 2, 0])");
5182   ASSERT_EQ(
5183       c10::str(ReplicationPad3d(ReplicationPad3dOptions(1))),
5184       "torch::nn::ReplicationPad3d(padding=[1, 1, 1, 1, 1, 1])");
5185   ASSERT_EQ(
5186       c10::str(ReplicationPad3d(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}))),
5187       "torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
5188 }
5189 
5190 TEST_F(ModulesTest, PrettyPrintZeroPad) {
5191   ASSERT_EQ(
5192       c10::str(ZeroPad1d(ZeroPad1dOptions(2))),
5193       "torch::nn::ZeroPad1d(padding=[2, 2])");
5194   ASSERT_EQ(
5195       c10::str(ZeroPad1d(ZeroPad1dOptions({3, 1}))),
5196       "torch::nn::ZeroPad1d(padding=[3, 1])");
5197   ASSERT_EQ(
5198       c10::str(ZeroPad2d(ZeroPad2dOptions(2))),
5199       "torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
5200   ASSERT_EQ(
5201       c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))),
5202       "torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");
5203   ASSERT_EQ(
5204       c10::str(ZeroPad3d(ZeroPad3dOptions(1))),
5205       "torch::nn::ZeroPad3d(padding=[1, 1, 1, 1, 1, 1])");
5206   ASSERT_EQ(
5207       c10::str(ZeroPad3d(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}))),
5208       "torch::nn::ZeroPad3d(padding=[1, 2, 1, 2, 1, 2])");
5209 }
5210 
5211 TEST_F(ModulesTest, PrettyPrintConstantPad) {
5212   ASSERT_EQ(
5213       c10::str(ConstantPad1d(ConstantPad1dOptions(2, 3.5))),
5214       "torch::nn::ConstantPad1d(padding=[2, 2], value=3.5)");
5215   ASSERT_EQ(
5216       c10::str(ConstantPad1d(ConstantPad1dOptions({3, 1}, 3.5))),
5217       "torch::nn::ConstantPad1d(padding=[3, 1], value=3.5)");
5218   ASSERT_EQ(
5219       c10::str(ConstantPad2d(ConstantPad2dOptions(2, 3.5))),
5220       "torch::nn::ConstantPad2d(padding=[2, 2, 2, 2], value=3.5)");
5221   ASSERT_EQ(
5222       c10::str(ConstantPad2d(ConstantPad2dOptions({3, 0, 2, 1}, 3.5))),
5223       "torch::nn::ConstantPad2d(padding=[3, 0, 2, 1], value=3.5)");
5224   ASSERT_EQ(
5225       c10::str(ConstantPad3d(ConstantPad3dOptions(1, 3.5))),
5226       "torch::nn::ConstantPad3d(padding=[1, 1, 1, 1, 1, 1], value=3.5)");
5227   ASSERT_EQ(
5228       c10::str(ConstantPad3d(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5))),
5229       "torch::nn::ConstantPad3d(padding=[1, 2, 1, 2, 1, 2], value=3.5)");
5230 }
5231 
5232 TEST_F(ModulesTest, PrettyPrintNestedModel) {
5233   struct InnerTestModule : torch::nn::Module {
5234     InnerTestModule()
5235         : torch::nn::Module("InnerTestModule"),
5236           fc(register_module("fc", torch::nn::Linear(3, 4))),
5237           table(register_module("table", torch::nn::Embedding(10, 2))) {}
5238 
5239     torch::nn::Linear fc;
5240     torch::nn::Embedding table;
5241   };
5242 
5243   struct TestModule : torch::nn::Module {
5244     TestModule()
5245         : torch::nn::Module("TestModule"),
5246           fc(register_module("fc", torch::nn::Linear(4, 5))),
5247           table(register_module(
5248               "table",
5249               torch::nn::Embedding(EmbeddingOptions(10, 2)))),
5250           inner(register_module("inner", std::make_shared<InnerTestModule>())) {
5251     }
5252 
5253     torch::nn::Linear fc;
5254     torch::nn::Embedding table;
5255     std::shared_ptr<InnerTestModule> inner;
5256   };
5257 
5258   ASSERT_EQ(
5259       c10::str(TestModule{}),
5260       "TestModule(\n"
5261       "  (fc): torch::nn::Linear(in_features=4, out_features=5, bias=true)\n"
5262       "  (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5263       "  (inner): InnerTestModule(\n"
5264       "    (fc): torch::nn::Linear(in_features=3, out_features=4, bias=true)\n"
5265       "    (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5266       "  )\n"
5267       ")");
5268 }
5269 
5270 TEST_F(ModulesTest, PrettyPrintELU) {
5271   ASSERT_EQ(c10::str(ELU()), "torch::nn::ELU(alpha=1)");
5272   ASSERT_EQ(
5273       c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))),
5274       "torch::nn::ELU(alpha=42.42, inplace=true)");
5275 }
5276 
5277 TEST_F(ModulesTest, PrettyPrintSELU) {
5278   ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()");
5279   ASSERT_EQ(
5280       c10::str(SELU(SELUOptions().inplace(true))),
5281       "torch::nn::SELU(inplace=true)");
5282 }
5283 
5284 TEST_F(ModulesTest, PrettyPrintGLU) {
5285   ASSERT_EQ(c10::str(GLU()), "torch::nn::GLU(dim=-1)");
5286   ASSERT_EQ(c10::str(GLU(1)), "torch::nn::GLU(dim=1)");
5287 }
5288 
5289 TEST_F(ModulesTest, PrettyPrintHardshrink) {
5290   ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)");
5291   ASSERT_EQ(
5292       c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))),
5293       "torch::nn::Hardshrink(42.42)");
5294 }
5295 
5296 TEST_F(ModulesTest, PrettyPrintHardtanh) {
5297   ASSERT_EQ(c10::str(Hardtanh()), "torch::nn::Hardtanh(min_val=-1, max_val=1)");
5298   ASSERT_EQ(
5299       c10::str(Hardtanh(
5300           HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))),
5301       "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)");
5302 }
5303 
5304 TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
5305   ASSERT_EQ(c10::str(LeakyReLU()), "torch::nn::LeakyReLU(negative_slope=0.01)");
5306   ASSERT_EQ(
5307       c10::str(
5308           LeakyReLU(LeakyReLUOptions().negative_slope(0.42).inplace(true))),
5309       "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
5310 }
5311 
5312 TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
5313   ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()");
5314 }
5315 
5316 TEST_F(ModulesTest, PrettyPrintSoftmax) {
5317   ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)");
5318 }
5319 
5320 TEST_F(ModulesTest, PrettyPrintSoftmin) {
5321   ASSERT_EQ(c10::str(Softmin(SoftminOptions(1))), "torch::nn::Softmin(dim=1)");
5322 }
5323 
5324 TEST_F(ModulesTest, PrettyPrintLogSoftmax) {
5325   ASSERT_EQ(
5326       c10::str(LogSoftmax(LogSoftmaxOptions(1))),
5327       "torch::nn::LogSoftmax(dim=1)");
5328 }
5329 
5330 TEST_F(ModulesTest, PrettyPrintSoftmax2d) {
5331   ASSERT_EQ(c10::str(Softmax2d()), "torch::nn::Softmax2d()");
5332 }
5333 
5334 TEST_F(ModulesTest, PrettyPrintPReLU) {
5335   ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)");
5336   ASSERT_EQ(
5337       c10::str(PReLU(PReLUOptions().num_parameters(42))),
5338       "torch::nn::PReLU(num_parameters=42)");
5339 }
5340 
5341 TEST_F(ModulesTest, PrettyPrintReLU) {
5342   ASSERT_EQ(c10::str(ReLU()), "torch::nn::ReLU()");
5343   ASSERT_EQ(
5344       c10::str(ReLU(ReLUOptions().inplace(true))),
5345       "torch::nn::ReLU(inplace=true)");
5346   ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), "torch::nn::ReLU(inplace=true)");
5347 }
5348 
5349 TEST_F(ModulesTest, PrettyPrintReLU6) {
5350   ASSERT_EQ(c10::str(ReLU6()), "torch::nn::ReLU6()");
5351   ASSERT_EQ(
5352       c10::str(ReLU6(ReLU6Options().inplace(true))),
5353       "torch::nn::ReLU6(inplace=true)");
5354   ASSERT_EQ(
5355       c10::str(ReLU6(/*inplace=*/true)), "torch::nn::ReLU6(inplace=true)");
5356 }
5357 
5358 TEST_F(ModulesTest, PrettyPrintRReLU) {
5359   ASSERT_EQ(c10::str(RReLU()), "torch::nn::RReLU(lower=0.125, upper=0.333333)");
5360   ASSERT_EQ(
5361       c10::str(RReLU(RReLUOptions().lower(0.24).upper(0.42).inplace(true))),
5362       "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)");
5363 }
5364 
5365 TEST_F(ModulesTest, PrettyPrintCELU) {
5366   ASSERT_EQ(c10::str(CELU()), "torch::nn::CELU(alpha=1)");
5367   ASSERT_EQ(
5368       c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))),
5369       "torch::nn::CELU(alpha=42.42, inplace=true)");
5370 }
5371 
5372 TEST_F(ModulesTest, PrettyPrintSigmoid) {
5373   ASSERT_EQ(c10::str(Sigmoid()), "torch::nn::Sigmoid()");
5374 }
5375 
5376 TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
5377   ASSERT_EQ(
5378       c10::str(PixelShuffle(PixelShuffleOptions(5))),
5379       "torch::nn::PixelShuffle(upscale_factor=5)");
5380 }
5381 
5382 TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
5383   ASSERT_EQ(
5384       c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
5385       "torch::nn::PixelUnshuffle(downscale_factor=5)");
5386 }
5387 
5388 TEST_F(ModulesTest, PrettyPrintSoftplus) {
5389   ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)");
5390   ASSERT_EQ(
5391       c10::str(Softplus(SoftplusOptions().beta(0.24).threshold(42.42))),
5392       "torch::nn::Softplus(beta=0.24, threshold=42.42)");
5393 }
5394 
5395 TEST_F(ModulesTest, PrettyPrintSoftshrink) {
5396   ASSERT_EQ(c10::str(Softshrink()), "torch::nn::Softshrink(0.5)");
5397   ASSERT_EQ(
5398       c10::str(Softshrink(SoftshrinkOptions(42.42))),
5399       "torch::nn::Softshrink(42.42)");
5400 }
5401 
5402 TEST_F(ModulesTest, PrettyPrintSoftsign) {
5403   ASSERT_EQ(c10::str(Softsign()), "torch::nn::Softsign()");
5404 }
5405 
5406 TEST_F(ModulesTest, PrettyPrintTanh) {
5407   ASSERT_EQ(c10::str(Tanh()), "torch::nn::Tanh()");
5408 }
5409 
5410 TEST_F(ModulesTest, PrettyPrintTanhshrink) {
5411   ASSERT_EQ(c10::str(Tanhshrink()), "torch::nn::Tanhshrink()");
5412 }
5413 
5414 TEST_F(ModulesTest, PrettyPrintThreshold) {
5415   ASSERT_EQ(
5416       c10::str(Threshold(24.24, 42.42)),
5417       "torch::nn::Threshold(threshold=24.24, value=42.42)");
5418   ASSERT_EQ(
5419       c10::str(Threshold(ThresholdOptions(42.42, 24.24).inplace(true))),
5420       "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)");
5421 }
5422 
5423 TEST_F(ModulesTest, PrettyPrintCTCLoss) {
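  // CTCLoss (like several loss modules below) prints only its class name;
  // pretty_print does not serialize the options, so even a customized
  // instance round-trips to "torch::nn::CTCLoss()".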
5424   ASSERT_EQ(c10::str(CTCLoss()), "torch::nn::CTCLoss()");
5425   ASSERT_EQ(
5426       c10::str(
5427           CTCLoss(CTCLossOptions().blank(42).zero_infinity(false).reduction(
5428               torch::kSum))),
5429       "torch::nn::CTCLoss()");
5430 }
5431 
5432 TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
5433   ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()");
5434   ASSERT_EQ(
5435       c10::str(PoissonNLLLoss(PoissonNLLLossOptions()
5436                                   .log_input(false)
5437                                   .full(true)
5438                                   .eps(0.42)
5439                                   .reduction(torch::kSum))),
5440       "torch::nn::PoissonNLLLoss()");
5441 }
5442 
5443 TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) {
5444   ASSERT_EQ(c10::str(MarginRankingLoss()), "torch::nn::MarginRankingLoss()");
5445   ASSERT_EQ(
5446       c10::str(MarginRankingLoss(
5447           MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))),
5448       "torch::nn::MarginRankingLoss()");
5449 }
5450 
5451 TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) {
5452   ASSERT_EQ(
5453       c10::str(CrossMapLRN2d(4)),
5454       "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)");
5455   ASSERT_EQ(
5456       c10::str(
5457           CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))),
5458       "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)");
5459 }
5460 
5461 TEST_F(ModulesTest, PrettyPrintAlphaDropout) {
5462   ASSERT_EQ(
5463       c10::str(AlphaDropout()),
5464       "torch::nn::AlphaDropout(p=0.5, inplace=false)");
5465   ASSERT_EQ(
5466       c10::str(AlphaDropout(AlphaDropoutOptions(0.2))),
5467       "torch::nn::AlphaDropout(p=0.2, inplace=false)");
5468   ASSERT_EQ(
5469       c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))),
5470       "torch::nn::AlphaDropout(p=0.2, inplace=true)");
5471 }
5472 
5473 TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) {
5474   ASSERT_EQ(
5475       c10::str(FeatureAlphaDropout()),
5476       "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)");
5477   ASSERT_EQ(
5478       c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))),
5479       "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)");
5480   ASSERT_EQ(
5481       c10::str(
5482           FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))),
5483       "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)");
5484 }
5485 
5486 TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) {
5487   ASSERT_EQ(c10::str(BCEWithLogitsLoss()), "torch::nn::BCEWithLogitsLoss()");
5488   ASSERT_EQ(
5489       c10::str(BCEWithLogitsLoss(BCEWithLogitsLossOptions()
5490                                      .weight(torch::ones({3, 3}))
5491                                      .pos_weight(torch::ones({3, 3}))
5492                                      .reduction(torch::kSum))),
5493       "torch::nn::BCEWithLogitsLoss()");
5494 }
5495 
5496 TEST_F(ModulesTest, PrettyPrintMultiheadAttention) {
5497   ASSERT_EQ(
5498       c10::str(MultiheadAttention(20, 10)),
5499       "torch::nn::MultiheadAttention(\n  (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)");
5500   ASSERT_EQ(
5501       c10::str(
5502           MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))),
5503       "torch::nn::MultiheadAttention(\n  (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)");
5504 }
5505 
5506 TEST_F(ModulesTest, PrettyPrintRNNCell) {
5507   ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)");
5508   ASSERT_EQ(
5509       c10::str(RNNCell(
5510           RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))),
5511       "torch::nn::RNNCell(20, 10, bias=false)");
5512   ASSERT_EQ(
5513       c10::str(RNNCell(
5514           RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))),
5515       "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)");
5516 }
5517 
5518 TEST_F(ModulesTest, PrettyPrintLSTMCell) {
5519   ASSERT_EQ(c10::str(LSTMCell(20, 10)), "torch::nn::LSTMCell(20, 10)");
5520   ASSERT_EQ(
5521       c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))),
5522       "torch::nn::LSTMCell(20, 10, bias=false)");
5523 }
5524 
5525 TEST_F(ModulesTest, PrettyPrintGRUCell) {
5526   ASSERT_EQ(c10::str(GRUCell(20, 10)), "torch::nn::GRUCell(20, 10)");
5527   ASSERT_EQ(
5528       c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))),
5529       "torch::nn::GRUCell(20, 10, bias=false)");
5530 }
5531 
5532 TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
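  // The head projects to cutoffs[0] + n_clusters outputs (shortlist classes
  // plus one logit per tail cluster), and tail cluster i shrinks its hidden
  // size by div_value^(i + 1). For options (8, 4, {2}) with div_value=2 that
  // is a 2 + 1 = 3-way head and a tail of 8 -> 4 -> 2, as asserted below.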
5533   {
5534     AdaptiveLogSoftmaxWithLoss asfm(
5535         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
5536     ASSERT_EQ(
5537         c10::str(asfm),
5538         "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5539         "  (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
5540         "  (tail): torch::nn::ModuleList(\n"
5541         "    (0): torch::nn::Sequential(\n"
5542         "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5543         "      (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
5544         "    )\n"
5545         "  )\n"
5546         ")");
5547   }
5548   {
5549     AdaptiveLogSoftmaxWithLoss asfm(
5550         AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
5551             .div_value(2.)
5552             .head_bias(true));
5553     ASSERT_EQ(
5554         c10::str(asfm),
5555         "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5556         "  (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
5557         "  (tail): torch::nn::ModuleList(\n"
5558         "    (0): torch::nn::Sequential(\n"
5559         "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5560         "      (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
5561         "    )\n"
5562         "    (1): torch::nn::Sequential(\n"
5563         "      (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
5564         "      (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
5565         "    )\n"
5566         "  )\n"
5567         ")");
5568   }
5569 }
5570