1 #include <gtest/gtest.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5
6 #include <test/cpp/api/support.h>
7
8 #include <torch/expanding_array.h>
9 #include <torch/nn/functional/activation.h>
10 #include <torch/nn/options/activation.h>
11 #include <limits>
12 #include <random>
13
14 using namespace torch::nn;
15 using namespace torch::test;
16
17 class TestModel : public torch::nn::Module {
18 public:
TestModel()19 TestModel()
20 : l1(register_module("l1", Linear(10, 3))),
21 l2(register_module("l2", Linear(3, 5))),
22 l3(register_module("l3", Linear(5, 100))) {}
23
24 Linear l1, l2, l3;
25 };
26
27 class NestedModel : public torch::nn::Module {
28 public:
NestedModel()29 NestedModel()
30 : param_(register_parameter("param", torch::empty({3, 2, 21}))),
31 l1(register_module("l1", Linear(5, 20))),
32 t(register_module("test", std::make_shared<TestModel>())) {}
33
34 torch::Tensor param_;
35 Linear l1;
36 std::shared_ptr<TestModel> t;
37 };
38
39 struct ModulesTest : torch::test::SeedingFixture {};
40
TEST_F(ModulesTest,Conv1d)41 TEST_F(ModulesTest, Conv1d) {
42 Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
43 model->weight.set_data(
44 torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3}));
45 auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
46 .reshape({2, 3, 5});
47 auto y = model(x);
48 auto expected = torch::tensor(
49 {{{312., 348., 384.}, {798., 915., 1032.}},
50
51 {{852., 888., 924.}, {2553., 2670., 2787.}}},
52 torch::kFloat);
53 ASSERT_TRUE(torch::allclose(y, expected));
54
55 torch::Tensor s = y.sum();
56 s.backward();
57 ASSERT_EQ(s.ndimension(), 0);
58 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
59 }
60
TEST_F(ModulesTest,Conv1dSameStrided)61 TEST_F(ModulesTest, Conv1dSameStrided) {
62 auto options = Conv1dOptions(3, 2, 3);
63 options.stride(1).padding(torch::kSame);
64 Conv1d model_valid(options);
65 ASSERT_THROWS_WITH(
66 [&] { Conv1d model_invalid(options.stride(2)); }(),
67 "padding='same' is not supported for strided convolutions");
68 }
69
TEST_F(ModulesTest,Conv1dIvalidArg)70 TEST_F(ModulesTest, Conv1dIvalidArg) {
71 auto options = Conv1dOptions(3, 2, 3).groups(-1);
72 ASSERT_THROWS_WITH(
73 Conv1d(options), "in_channels, groups and out_channels must");
74 }
75
TEST_F(ModulesTest,Conv2dEven)76 TEST_F(ModulesTest, Conv2dEven) {
77 Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
78 model->weight.set_data(
79 torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3}));
80 auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
81 .reshape({1, 3, 5, 5});
82 auto y = model(x);
83 auto expected = torch::tensor(
84 {{{{15219., 15570., 15921.},
85 {16974., 17325., 17676.},
86 {18729., 19080., 19431.}},
87
88 {{37818., 38898., 39978.},
89 {43218., 44298., 45378.},
90 {48618., 49698., 50778.}}}},
91 torch::kFloat);
92 ASSERT_TRUE(torch::allclose(y, expected));
93
94 torch::Tensor s = y.sum();
95 s.backward();
96 ASSERT_EQ(s.ndimension(), 0);
97 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
98 }
99
TEST_F(ModulesTest,Conv2dUneven)100 TEST_F(ModulesTest, Conv2dUneven) {
101 Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
102 model->weight.set_data(
103 torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2}));
104 auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
105 .reshape({1, 3, 5, 4});
106 auto y = model(x);
107 auto expected = torch::tensor(
108 {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},
109
110 {{13227., 13704., 14181.},
111 {15135., 15612., 16089.},
112 {17043., 17520., 17997.}}}},
113 torch::kFloat);
114 ASSERT_TRUE(torch::allclose(y, expected));
115
116 torch::Tensor s = y.sum();
117 s.backward();
118 ASSERT_EQ(s.ndimension(), 0);
119 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
120 }
121
TEST_F(ModulesTest,Conv2dSameStrided)122 TEST_F(ModulesTest, Conv2dSameStrided) {
123 auto options = Conv2dOptions(3, 2, {3, 4});
124 options.stride(1).padding(torch::kSame);
125 Conv2d model_valid(options);
126 ASSERT_THROWS_WITH(
127 [&] { Conv2d model_invalid(options.stride(2)); }(),
128 "padding='same' is not supported for strided convolutions");
129 ASSERT_THROWS_WITH(
130 [&] {
131 Conv2d model_invalid(options.stride({1, 2}));
132 }(),
133 "padding='same' is not supported for strided convolutions");
134 }
135
TEST_F(ModulesTest,Conv3d)136 TEST_F(ModulesTest, Conv3d) {
137 Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
138 model->weight.set_data(
139 torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3}));
140 auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
141 .reshape({1, 3, 5, 5, 5});
142 auto y = model(x);
143 auto expected = torch::tensor(
144 {{{{{700704., 703944., 707184.},
145 {716904., 720144., 723384.},
146 {733104., 736344., 739584.}},
147
148 {{781704., 784944., 788184.},
149 {797904., 801144., 804384.},
150 {814104., 817344., 820584.}},
151
152 {{862704., 865944., 869184.},
153 {878904., 882144., 885384.},
154 {895104., 898344., 901584.}}},
155
156 {{{1724220., 1734021., 1743822.},
157 {1773225., 1783026., 1792827.},
158 {1822230., 1832031., 1841832.}},
159
160 {{1969245., 1979046., 1988847.},
161 {2018250., 2028051., 2037852.},
162 {2067255., 2077056., 2086857.}},
163
164 {{2214270., 2224071., 2233872.},
165 {2263275., 2273076., 2282877.},
166 {2312280., 2322081., 2331882.}}}}},
167 torch::kFloat);
168 ASSERT_TRUE(torch::allclose(y, expected));
169
170 torch::Tensor s = y.sum();
171 s.backward();
172 ASSERT_EQ(s.ndimension(), 0);
173 ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3);
174 }
175
TEST_F(ModulesTest,Conv3dSameStrided)176 TEST_F(ModulesTest, Conv3dSameStrided) {
177 auto options = Conv3dOptions(3, 2, {3, 4, 5});
178 options.stride(1).padding(torch::kSame);
179 Conv3d model_valid(options);
180 ASSERT_THROWS_WITH(
181 [&] { Conv3d model_invalid(options.stride(2)); }(),
182 "padding='same' is not supported for strided convolutions");
183 ASSERT_THROWS_WITH(
184 [&] {
185 Conv3d model_invalid(options.stride({1, 2, 1}));
186 }(),
187 "padding='same' is not supported for strided convolutions");
188 }
189
TEST_F(ModulesTest,ConvTranspose1d)190 TEST_F(ModulesTest, ConvTranspose1d) {
191 ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false));
192 model->weight.set_data(torch::arange(18.).view({2, 3, 3}));
193 auto x = torch::arange(20.).reshape({2, 2, 5});
194 auto y = model(x);
195 auto expected = torch::tensor(
196 {{{45., 104., 179., 212., 245., 188., 107.},
197 {60., 140., 242., 293., 344., 260., 146.},
198 {75., 176., 305., 374., 443., 332., 185.}},
199 {{135., 304., 509., 542., 575., 428., 237.},
200 {210., 460., 752., 803., 854., 620., 336.},
201 {285., 616., 995., 1064., 1133., 812., 435.}}});
202 ASSERT_TRUE(torch::allclose(y, expected));
203
204 torch::Tensor s = y.sum();
205 s.backward();
206 ASSERT_EQ(s.ndimension(), 0);
207 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
208 }
209
TEST_F(ModulesTest,ConvTranspose2dEven)210 TEST_F(ModulesTest, ConvTranspose2dEven) {
211 ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false));
212 model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3}));
213 auto x = torch::arange(50.).view({1, 2, 5, 5});
214 auto y = model(x);
215 auto expected = torch::tensor(
216 {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
217 {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
218 {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
219 {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
220 {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
221 {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
222 {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
223 {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
224 {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
225 {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
226 {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
227 {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
228 {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
229 {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
230 {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
231 {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
232 {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
233 {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
234 {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
235 {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
236 {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
237 ASSERT_TRUE(torch::allclose(y, expected));
238
239 torch::Tensor s = y.sum();
240 s.backward();
241 ASSERT_EQ(s.ndimension(), 0);
242 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
243 }
244
TEST_F(ModulesTest,ConvTranspose2dUneven)245 TEST_F(ModulesTest, ConvTranspose2dUneven) {
246 ConvTranspose2d model(
247 ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
248 model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2}));
249 auto x = torch::arange(40.).view({1, 2, 5, 4});
250 auto y = model(x);
251 auto expected = torch::tensor(
252 {{{{360., 758., 796., 834., 440.},
253 {832., 1752., 1836., 1920., 1012.},
254 {1432., 3014., 3152., 3290., 1732.},
255 {1696., 3566., 3704., 3842., 2020.},
256 {1960., 4118., 4256., 4394., 2308.},
257 {1504., 3152., 3252., 3352., 1756.},
258 {856., 1790., 1844., 1898., 992.}},
259 {{480., 1010., 1072., 1134., 596.},
260 {1120., 2352., 2484., 2616., 1372.},
261 {1936., 4058., 4268., 4478., 2344.},
262 {2344., 4898., 5108., 5318., 2776.},
263 {2752., 5738., 5948., 6158., 3208.},
264 {2080., 4328., 4476., 4624., 2404.},
265 {1168., 2426., 2504., 2582., 1340.}},
266 {{600., 1262., 1348., 1434., 752.},
267 {1408., 2952., 3132., 3312., 1732.},
268 {2440., 5102., 5384., 5666., 2956.},
269 {2992., 6230., 6512., 6794., 3532.},
270 {3544., 7358., 7640., 7922., 4108.},
271 {2656., 5504., 5700., 5896., 3052.},
272 {1480., 3062., 3164., 3266., 1688.}}}});
273 ASSERT_TRUE(torch::allclose(y, expected));
274
275 torch::Tensor s = y.sum();
276 s.backward();
277 ASSERT_EQ(s.ndimension(), 0);
278 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
279 }
280
TEST_F(ModulesTest,ConvTranspose3d)281 TEST_F(ModulesTest, ConvTranspose3d) {
282 ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false));
283 model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
284 auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
285 auto y = model(x);
286 auto expected = torch::tensor(
287 {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
288 {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
289 {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
290 {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
291 {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
292 {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
293 ASSERT_TRUE(torch::allclose(y, expected));
294
295 torch::Tensor s = y.sum();
296 s.backward();
297 ASSERT_EQ(s.ndimension(), 0);
298 ASSERT_TRUE(model->weight.grad().numel() == 2 * 2 * 2 * 2 * 2);
299 }
300
TEST_F(ModulesTest,MaxPool1d)301 TEST_F(ModulesTest, MaxPool1d) {
302 MaxPool1d model(MaxPool1dOptions(3).stride(2));
303 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
304 auto y = model(x);
305 torch::Tensor s = y.sum();
306
307 s.backward();
308 ASSERT_EQ(y.ndimension(), 3);
309 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
310 ASSERT_EQ(s.ndimension(), 0);
311 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
312 }
313
TEST_F(ModulesTest,MaxPool1dReturnIndices)314 TEST_F(ModulesTest, MaxPool1dReturnIndices) {
315 MaxPool1d model(MaxPool1dOptions(3).stride(2));
316 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
317 auto [y, indices] = model->forward_with_indices(x);
318
319 ASSERT_EQ(y.dim(), 3);
320 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
321 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
322
323 ASSERT_TRUE(
324 torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong)));
325 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 2}));
326 }
327
TEST_F(ModulesTest,MaxPool2dEven)328 TEST_F(ModulesTest, MaxPool2dEven) {
329 MaxPool2d model(MaxPool2dOptions(3).stride(2));
330 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
331 auto y = model(x);
332 torch::Tensor s = y.sum();
333
334 s.backward();
335 ASSERT_EQ(y.ndimension(), 3);
336 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
337 ASSERT_EQ(s.ndimension(), 0);
338 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
339 }
340
TEST_F(ModulesTest,MaxPool2dUneven)341 TEST_F(ModulesTest, MaxPool2dUneven) {
342 MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2}));
343 auto x = torch::ones({2, 5, 4}, torch::requires_grad());
344 auto y = model(x);
345 torch::Tensor s = y.sum();
346
347 s.backward();
348 ASSERT_EQ(y.ndimension(), 3);
349 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
350 ASSERT_EQ(s.ndimension(), 0);
351 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
352 }
353
TEST_F(ModulesTest,MaxPool2dReturnIndices)354 TEST_F(ModulesTest, MaxPool2dReturnIndices) {
355 MaxPool2d model(MaxPool2dOptions(3).stride(2));
356 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
357 auto [y, indices] = model->forward_with_indices(x);
358
359 ASSERT_EQ(y.dim(), 3);
360 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
361 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
362 ASSERT_TRUE(torch::allclose(
363 indices,
364 torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong)));
365 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
366 }
367
TEST_F(ModulesTest,MaxPool3d)368 TEST_F(ModulesTest, MaxPool3d) {
369 MaxPool3d model(MaxPool3dOptions(3).stride(2));
370 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
371 auto y = model(x);
372 torch::Tensor s = y.sum();
373
374 s.backward();
375 ASSERT_EQ(y.ndimension(), 4);
376 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
377 ASSERT_EQ(s.ndimension(), 0);
378 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
379 }
380
TEST_F(ModulesTest,MaxPool3dReturnIndices)381 TEST_F(ModulesTest, MaxPool3dReturnIndices) {
382 MaxPool3d model(MaxPool3dOptions(3).stride(2));
383 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
384 auto [y, indices] = model->forward_with_indices(x);
385
386 ASSERT_EQ(y.dim(), 4);
387 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
388 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
389
390 ASSERT_TRUE(torch::allclose(
391 indices,
392 torch::tensor(
393 {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
394 {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}},
395 torch::kLong)));
396 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
397 }
398
TEST_F(ModulesTest,AvgPool1d)399 TEST_F(ModulesTest, AvgPool1d) {
400 AvgPool1d model(AvgPool1dOptions(3).stride(2));
401 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
402 auto y = model(x);
403 torch::Tensor s = y.sum();
404
405 s.backward();
406 ASSERT_EQ(y.ndimension(), 3);
407 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
408 ASSERT_EQ(s.ndimension(), 0);
409 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
410 }
411
TEST_F(ModulesTest,AvgPool2dEven)412 TEST_F(ModulesTest, AvgPool2dEven) {
413 AvgPool2d model(AvgPool2dOptions(3).stride(2));
414 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
415 auto y = model(x);
416 torch::Tensor s = y.sum();
417
418 s.backward();
419 ASSERT_EQ(y.ndimension(), 3);
420 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
421 ASSERT_EQ(s.ndimension(), 0);
422 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
423 }
424
TEST_F(ModulesTest,AvgPool2dUneven)425 TEST_F(ModulesTest, AvgPool2dUneven) {
426 AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2}));
427 auto x = torch::ones({2, 5, 4}, torch::requires_grad());
428 auto y = model(x);
429 torch::Tensor s = y.sum();
430
431 s.backward();
432 ASSERT_EQ(y.ndimension(), 3);
433 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
434 ASSERT_EQ(s.ndimension(), 0);
435 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
436 }
437
TEST_F(ModulesTest,AvgPool3d)438 TEST_F(ModulesTest, AvgPool3d) {
439 AvgPool3d model(AvgPool3dOptions(3).stride(2));
440 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
441 auto y = model(x);
442 torch::Tensor s = y.sum();
443
444 s.backward();
445 ASSERT_EQ(y.ndimension(), 4);
446 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
447 ASSERT_EQ(s.ndimension(), 0);
448 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
449 }
450
TEST_F(ModulesTest,FractionalMaxPool2d)451 TEST_F(ModulesTest, FractionalMaxPool2d) {
452 FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
453 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
454 auto y = model(x);
455 torch::Tensor s = y.sum();
456
457 s.backward();
458 ASSERT_EQ(y.ndimension(), 3);
459 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
460 ASSERT_EQ(s.ndimension(), 0);
461 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
462 }
463
TEST_F(ModulesTest,FractionalMaxPool2dReturnIndices)464 TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) {
465 FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
466 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
467 auto [y, indices] = model->forward_with_indices(x);
468
469 ASSERT_EQ(y.dim(), 3);
470 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
471 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
472 ASSERT_TRUE(torch::allclose(
473 indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
474 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
475 }
476
TEST_F(ModulesTest,FractionalMaxPool3d)477 TEST_F(ModulesTest, FractionalMaxPool3d) {
478 FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
479 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
480 auto y = model(x);
481 torch::Tensor s = y.sum();
482
483 s.backward();
484 ASSERT_EQ(y.ndimension(), 4);
485 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
486 ASSERT_EQ(s.ndimension(), 0);
487 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
488 }
489
TEST_F(ModulesTest,FractionalMaxPool3dReturnIndices)490 TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) {
491 FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
492 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
493 auto [y, indices] = model->forward_with_indices(x);
494
495 ASSERT_EQ(y.dim(), 4);
496 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
497 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
498
499 ASSERT_TRUE(torch::allclose(
500 indices,
501 torch::tensor(
502 {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
503 {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
504 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
505 }
506
TEST_F(ModulesTest,LPPool1d)507 TEST_F(ModulesTest, LPPool1d) {
508 int norm_type = 2;
509 int stride = 2;
510 int kernel_size = 3;
511
512 LPPool1d model(LPPool1dOptions(norm_type, kernel_size).stride(stride));
513 auto x = torch::ones({1, 1, 5});
514 auto y = model(x);
515 auto expected =
516 (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
517 kernel_size)
518 .pow(1. / norm_type);
519
520 ASSERT_EQ(y.ndimension(), 3);
521 ASSERT_TRUE(torch::allclose(y, expected));
522 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
523 }
524
TEST_F(ModulesTest,LPPool2d)525 TEST_F(ModulesTest, LPPool2d) {
526 int norm_type = 2;
527 int stride = 2;
528 std::vector<int64_t> kernel_size({2, 3});
529
530 LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
531 auto x = torch::ones({1, 1, 2, 5});
532 auto y = model(x);
533 auto expected =
534 (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
535 (kernel_size[0] * kernel_size[1]))
536 .pow(1. / norm_type);
537
538 ASSERT_EQ(y.ndimension(), 4);
539 ASSERT_TRUE(torch::allclose(y, expected));
540 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
541 }
542
TEST_F(ModulesTest,LPPool3d)543 TEST_F(ModulesTest, LPPool3d) {
544 int norm_type = 2;
545 int stride = 2;
546 std::vector<int64_t> kernel_size({1, 2, 3});
547
548 LPPool3d model(LPPool3dOptions(norm_type, kernel_size).stride(stride));
549 auto x = torch::ones({1, 1, 1, 2, 5});
550 auto y = model(x);
551 auto expected =
552 (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
553 (kernel_size[0] * kernel_size[1] * kernel_size[2]))
554 .pow(1. / norm_type);
555
556 ASSERT_EQ(y.ndimension(), 5);
557 ASSERT_TRUE(torch::allclose(y, expected));
558 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
559 }
560
TEST_F(ModulesTest,Identity)561 TEST_F(ModulesTest, Identity) {
562 Identity identity;
563 auto input = torch::tensor(
564 {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
565 auto output = identity->forward(input);
566 auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
567 auto s = output.sum();
568 s.backward();
569
570 ASSERT_TRUE(torch::equal(output, expected));
571 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
572 }
573
TEST_F(ModulesTest,Flatten)574 TEST_F(ModulesTest, Flatten) {
575 Flatten flatten;
576 auto input = torch::tensor(
577 {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
578 auto output = flatten->forward(input);
579 auto expected = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat);
580 auto s = output.sum();
581
582 s.backward();
583 ASSERT_TRUE(torch::equal(output, expected));
584 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
585
586 // Testing with optional arguments start_dim and end_dim
587 Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3));
588 input = torch::tensor(
589 {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
590 {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}},
591 torch::dtype(torch::kFloat)
592 .requires_grad(true)); // Tensor with sizes (2, 2, 2, 2)
593
594 output = flatten_optional_dims->forward(input);
595 expected = torch::tensor(
596 {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}},
597 torch::kFloat); // Tensor with sizes (2, 2, 4)
598
599 s = output.sum();
600 s.backward();
601 ASSERT_TRUE(torch::equal(output, expected));
602 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
603 }
604
TEST_F(ModulesTest,Unflatten)605 TEST_F(ModulesTest, Unflatten) {
606 // Non-named tensor
607 Unflatten unflatten(UnflattenOptions(0, {2, 2}));
608 auto output = unflatten->forward(torch::tensor({1, 2, 3, 4}));
609 auto expected = torch::tensor({{1, 2}, {3, 4}});
610 ASSERT_TRUE(torch::equal(output, expected));
611
612 // Named tensor
613 auto make_dimnames = [](std::vector<std::string> names) {
614 std::vector<torch::Dimname> dimnames;
615 // NOLINTNEXTLINE(performance-for-range-copy)
616 for (auto name : names) {
617 // NOLINTNEXTLINE(performance-inefficient-vector-operation)
618 dimnames.push_back(
619 torch::Dimname::fromSymbol(torch::Symbol::dimname(name)));
620 }
621 return dimnames;
622 };
623
624 unflatten = Unflatten(UnflattenOptions(
625 "B",
626 {std::pair<std::string, int64_t>{"B1", 2},
627 std::pair<std::string, int64_t>{"B2", 2}}));
628 output = unflatten->forward(
629 torch::tensor({{1, 2, 3, 4}}).refine_names(make_dimnames({"A", "B"})));
630 expected = torch::tensor({{{1, 2}, {3, 4}}})
631 .refine_names(make_dimnames({"A", "B1", "B2"}));
632 ASSERT_TRUE(torch::equal(output, expected));
633 }
634
TEST_F(ModulesTest,AdaptiveMaxPool1d)635 TEST_F(ModulesTest, AdaptiveMaxPool1d) {
636 AdaptiveMaxPool1d model(3);
637 auto x = torch::tensor(
638 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
639 auto y = model(x);
640 torch::Tensor s = y.sum();
641
642 s.backward();
643 ASSERT_EQ(y.ndimension(), 3);
644 ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
645 ASSERT_EQ(s.ndimension(), 0);
646 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
647 }
648
TEST_F(ModulesTest,AdaptiveMaxPool1dReturnIndices)649 TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) {
650 AdaptiveMaxPool1d model(3);
651 auto x = torch::tensor(
652 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
653 auto [y, indices] = model->forward_with_indices(x);
654
655 ASSERT_EQ(y.dim(), 3);
656 ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
657 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
658 ASSERT_TRUE(
659 torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong)));
660 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 3}));
661 }
662
TEST_F(ModulesTest,AdaptiveMaxPool2dEven)663 TEST_F(ModulesTest, AdaptiveMaxPool2dEven) {
664 AdaptiveMaxPool2d model(3);
665 auto x = torch::arange(0., 50);
666 x.resize_({2, 5, 5}).set_requires_grad(true);
667 auto y = model(x);
668 torch::Tensor s = y.sum();
669
670 s.backward();
671 ASSERT_EQ(y.ndimension(), 3);
672 ASSERT_TRUE(torch::allclose(
673 y,
674 torch::tensor(
675 {
676 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
677 {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
678 },
679 torch::kFloat)));
680 ASSERT_EQ(s.ndimension(), 0);
681 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
682 }
683
TEST_F(ModulesTest,AdaptiveMaxPool2dUneven)684 TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) {
685 AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
686 auto x = torch::arange(0., 40);
687 x.resize_({2, 5, 4}).set_requires_grad(true);
688 auto y = model(x);
689 torch::Tensor s = y.sum();
690
691 s.backward();
692 ASSERT_EQ(y.ndimension(), 3);
693 ASSERT_TRUE(torch::allclose(
694 y,
695 torch::tensor(
696 {
697 {{5, 7}, {13, 15}, {17, 19}},
698 {{25, 27}, {33, 35}, {37, 39}},
699 },
700 torch::kFloat)));
701 ASSERT_EQ(s.ndimension(), 0);
702 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
703 }
704
TEST_F(ModulesTest,AdaptiveMaxPool2dReturnIndicesEven)705 TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) {
706 AdaptiveMaxPool2d model(3);
707 auto x = torch::arange(0., 50);
708 x.resize_({2, 5, 5}).set_requires_grad(true);
709 auto [y, indices] = model->forward_with_indices(x);
710 torch::Tensor s = y.sum();
711
712 s.backward();
713 ASSERT_EQ(s.ndimension(), 0);
714
715 ASSERT_EQ(y.ndimension(), 3);
716 ASSERT_TRUE(torch::allclose(
717 y,
718 torch::tensor(
719 {
720 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
721 {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
722 },
723 torch::kFloat)));
724 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
725
726 ASSERT_EQ(indices.ndimension(), 3);
727 ASSERT_TRUE(torch::allclose(
728 indices,
729 torch::tensor(
730 {
731 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
732 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
733 },
734 torch::kLong)));
735 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 3}));
736 }
737
TEST_F(ModulesTest,AdaptiveMaxPool2dReturnIndicesUneven)738 TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) {
739 AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
740 auto x = torch::arange(0., 40);
741 x.resize_({2, 5, 4}).set_requires_grad(true);
742 auto [y, indices] = model->forward_with_indices(x);
743 torch::Tensor s = y.sum();
744
745 s.backward();
746 ASSERT_EQ(s.ndimension(), 0);
747
748 ASSERT_EQ(y.ndimension(), 3);
749 ASSERT_TRUE(torch::allclose(
750 y,
751 torch::tensor(
752 {
753 {{5, 7}, {13, 15}, {17, 19}},
754 {{25, 27}, {33, 35}, {37, 39}},
755 },
756 torch::kFloat)));
757 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
758
759 ASSERT_EQ(indices.ndimension(), 3);
760 ASSERT_TRUE(torch::allclose(
761 indices,
762 torch::tensor(
763 {
764 {{5, 7}, {13, 15}, {17, 19}},
765 {{5, 7}, {13, 15}, {17, 19}},
766 },
767 torch::kLong)));
768 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 2}));
769 }
770
TEST_F(ModulesTest,AdaptiveMaxPool3d)771 TEST_F(ModulesTest, AdaptiveMaxPool3d) {
772 AdaptiveMaxPool3d model(3);
773 auto x = torch::arange(0., 64);
774 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
775 auto y = model(x);
776 torch::Tensor s = y.sum();
777
778 s.backward();
779 ASSERT_EQ(s.ndimension(), 0);
780
781 ASSERT_EQ(y.ndimension(), 4);
782 ASSERT_TRUE(torch::allclose(
783 y,
784 torch::tensor(
785 {
786 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
787 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
788 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
789 },
790 torch::kFloat)));
791 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
792 }
793
TEST_F(ModulesTest,AdaptiveMaxPool3dReturnIndices)794 TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) {
795 AdaptiveMaxPool3d model(3);
796 auto x = torch::arange(0., 64);
797 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
798 auto [y, indices] = model->forward_with_indices(x);
799 torch::Tensor s = y.sum();
800
801 s.backward();
802 ASSERT_EQ(s.ndimension(), 0);
803
804 ASSERT_EQ(y.ndimension(), 4);
805 ASSERT_TRUE(torch::allclose(
806 y,
807 torch::tensor(
808 {
809 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
810 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
811 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
812 },
813 torch::kFloat)));
814 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
815
816 ASSERT_EQ(indices.ndimension(), 4);
817 ASSERT_TRUE(torch::allclose(
818 indices,
819 torch::tensor(
820 {
821 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
822 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
823 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
824 },
825 torch::kLong)));
826 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
827 }
828
TEST_F(ModulesTest,AdaptiveAvgPool1d)829 TEST_F(ModulesTest, AdaptiveAvgPool1d) {
830 AdaptiveAvgPool1d model(3);
831 auto x = torch::tensor(
832 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
833 auto y = model(x);
834 torch::Tensor s = y.sum();
835
836 s.backward();
837 ASSERT_EQ(s.ndimension(), 0);
838
839 ASSERT_EQ(y.ndimension(), 3);
840 ASSERT_TRUE(
841 torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat)));
842 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
843 }
844
TEST_F(ModulesTest,AdaptiveAvgPool2dEven)845 TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
846 AdaptiveAvgPool2d model(3);
847 auto x = torch::arange(0., 50);
848 x.resize_({2, 5, 5}).set_requires_grad(true);
849 auto y = model(x);
850 torch::Tensor s = y.sum();
851
852 s.backward();
853 ASSERT_EQ(s.ndimension(), 0);
854
855 ASSERT_EQ(y.ndimension(), 3);
856 ASSERT_TRUE(torch::allclose(
857 y,
858 torch::tensor(
859 {
860 {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}},
861 {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}},
862 },
863 torch::kFloat)));
864 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
865 }
866
TEST_F(ModulesTest,AdaptiveAvgPool2dUneven)867 TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
868 AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
869 auto x = torch::arange(0., 40);
870 x.resize_({2, 5, 4}).set_requires_grad(true);
871 auto y = model(x);
872 torch::Tensor s = y.sum();
873
874 s.backward();
875 ASSERT_EQ(s.ndimension(), 0);
876
877 ASSERT_EQ(y.ndimension(), 3);
878 ASSERT_TRUE(torch::allclose(
879 y,
880 torch::tensor(
881 {
882 {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}},
883 {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}},
884 },
885 torch::kFloat)));
886 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
887 }
888
TEST_F(ModulesTest,AdaptiveAvgPool3d)889 TEST_F(ModulesTest, AdaptiveAvgPool3d) {
890 AdaptiveAvgPool3d model(3);
891 auto x = torch::arange(0., 64);
892 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
893 auto y = model(x);
894 torch::Tensor s = y.sum();
895
896 s.backward();
897 ASSERT_EQ(s.ndimension(), 0);
898
899 ASSERT_EQ(y.ndimension(), 4);
900 ASSERT_TRUE(torch::allclose(
901 y,
902 torch::tensor(
903 {
904 {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}},
905 {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}},
906 {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}},
907 },
908 torch::kFloat)));
909 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
910 }
911
TEST_F(ModulesTest,MaxUnpool1d)912 TEST_F(ModulesTest, MaxUnpool1d) {
913 auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
914 auto x = torch::tensor(
915 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
916 auto model = MaxUnpool1d{3};
917 auto y = model->forward(x, indices);
918
919 ASSERT_EQ(y.dim(), 3);
920 ASSERT_TRUE(torch::allclose(
921 y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
922 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
923
924 indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
925 x = torch::tensor(
926 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
927 model = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)};
928 y = model->forward(x, indices, std::vector<int64_t>({1, 1, 5}));
929
930 ASSERT_EQ(y.dim(), 3);
931 ASSERT_TRUE(
932 torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
933 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
934 }
935
TEST_F(ModulesTest,MaxPool1d_MaxUnpool1d)936 TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) {
937 MaxPool1d pool{MaxPool1dOptions(2).stride(2)};
938 MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)};
939 auto input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat);
940 auto [output, indices] = pool->forward_with_indices(input);
941 ASSERT_TRUE(torch::allclose(
942 unpool(output, indices),
943 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));
944
945 // Example showcasing the use of output_size
946 input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat);
947 std::tie(output, indices) = pool->forward_with_indices(input);
948 ASSERT_TRUE(torch::allclose(
949 unpool(output, indices, input.sizes().vec()),
950 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat)));
951 ASSERT_TRUE(torch::allclose(
952 unpool(output, indices),
953 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));
954 }
955
TEST_F(ModulesTest,MaxUnpool2d)956 TEST_F(ModulesTest, MaxUnpool2d) {
957 auto indices = torch::tensor(
958 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
959 {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
960 torch::kLong);
961 auto x = torch::tensor(
962 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
963 {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
964 torch::dtype(torch::kFloat).requires_grad(true));
965 auto model = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)};
966 auto y = model->forward(x, indices);
967
968 ASSERT_EQ(y.dim(), 4);
969 ASSERT_TRUE(torch::allclose(
970 y,
971 torch::tensor(
972 {{{{0, 0, 0, 0, 0},
973 {0, 6, 0, 8, 9},
974 {0, 0, 0, 0, 0},
975 {0, 16, 0, 18, 19},
976 {0, 21, 0, 23, 24}}},
977 {{{0, 0, 0, 0, 0},
978 {0, 31, 0, 33, 34},
979 {0, 0, 0, 0, 0},
980 {0, 41, 0, 43, 44},
981 {0, 46, 0, 48, 49}}}},
982 torch::kFloat)));
983 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
984 }
985
TEST_F(ModulesTest,MaxPool2d_MaxUnpool2d)986 TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) {
987 MaxPool2d pool{MaxPool2dOptions(2).stride(2)};
988 MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)};
989 auto input = torch::tensor(
990 {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}},
991 torch::kFloat);
992 auto [output, indices] = pool->forward_with_indices(input);
993 ASSERT_TRUE(torch::allclose(
994 unpool(output, indices),
995 torch::tensor(
996 {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}},
997 torch::kFloat)));
998
999 ASSERT_TRUE(torch::allclose(
1000 unpool(output, indices, std::vector<int64_t>{1, 1, 5, 5}),
1001 torch::tensor(
1002 {{{{0, 0, 0, 0, 0},
1003 {6, 0, 8, 0, 0},
1004 {0, 0, 0, 14, 0},
1005 {16, 0, 0, 0, 0},
1006 {0, 0, 0, 0, 0}}}},
1007 torch::kFloat)));
1008 }
1009
TEST_F(ModulesTest,MaxUnpool3d)1010 TEST_F(ModulesTest, MaxUnpool3d) {
1011 auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
1012 auto x = torch::tensor(
1013 {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
1014 auto model = MaxUnpool3d{3};
1015 auto y = model->forward(x, indices);
1016
1017 ASSERT_EQ(y.dim(), 5);
1018 ASSERT_TRUE(torch::allclose(
1019 y,
1020 torch::tensor(
1021 {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
1022 {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
1023 {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
1024 torch::kFloat)));
1025 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
1026 }
1027
TEST_F(ModulesTest,MaxUnpool3dOutputSize)1028 TEST_F(ModulesTest, MaxUnpool3dOutputSize) {
1029 auto indices = torch::tensor(
1030 {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong);
1031 auto x = torch::tensor(
1032 {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}},
1033 torch::dtype(torch::kFloat).requires_grad(true));
1034 auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)};
1035 auto y = model->forward(x, indices, std::vector<int64_t>({1, 1, 4, 4, 4}));
1036
1037 ASSERT_EQ(y.dim(), 5);
1038 ASSERT_TRUE(torch::allclose(
1039 y,
1040 torch::tensor(
1041 {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
1042 {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}},
1043 {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
1044 {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}},
1045 torch::kFloat)));
1046 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 4, 4, 4}));
1047 }
1048
TEST_F(ModulesTest,MaxPool3d_MaxUnpool3d)1049 TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) {
1050 MaxPool3d pool{MaxPool3dOptions(3).stride(2)};
1051 MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)};
1052 auto input = torch::randn({20, 16, 51, 33, 15});
1053 auto [output, indices] = pool->forward_with_indices(input);
1054 auto unpooled_output = unpool(output, indices);
1055 ASSERT_EQ(
1056 unpooled_output.sizes(), std::vector<int64_t>({20, 16, 51, 33, 15}));
1057 }
1058
TEST_F(ModulesTest,Linear)1059 TEST_F(ModulesTest, Linear) {
1060 {
1061 Linear model(5, 2);
1062 auto x = torch::randn({10, 5}, torch::requires_grad());
1063 auto y = model(x);
1064 torch::Tensor s = y.sum();
1065
1066 s.backward();
1067 ASSERT_EQ(y.ndimension(), 2);
1068 ASSERT_EQ(s.ndimension(), 0);
1069 ASSERT_EQ(y.size(0), 10);
1070 ASSERT_EQ(y.size(1), 2);
1071
1072 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1073
1074 auto y_exp = torch::addmm(model->bias, x, model->weight.t());
1075 ASSERT_TRUE(torch::allclose(y, y_exp));
1076 }
1077 {
1078 Linear model(LinearOptions(5, 2).bias(false));
1079 auto x = torch::randn({10, 5}, torch::requires_grad());
1080 auto y = model(x);
1081 torch::Tensor s = y.sum();
1082
1083 s.backward();
1084 ASSERT_EQ(y.ndimension(), 2);
1085 ASSERT_EQ(s.ndimension(), 0);
1086 ASSERT_EQ(y.size(0), 10);
1087 ASSERT_EQ(y.size(1), 2);
1088
1089 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1090
1091 auto y_exp = torch::mm(x, model->weight.t());
1092 ASSERT_TRUE(torch::allclose(y, y_exp));
1093 }
1094 }
1095
TEST_F(ModulesTest,LocalResponseNorm)1096 TEST_F(ModulesTest, LocalResponseNorm) {
1097 {
1098 LocalResponseNorm model(LocalResponseNormOptions(2));
1099 const auto x =
1100 torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2});
1101 auto y = model(x);
1102 const auto y_exp = torch::tensor(
1103 {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}},
1104
1105 {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}},
1106
1107 {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}},
1108
1109 {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}},
1110
1111 {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}},
1112
1113 {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}},
1114 torch::kFloat);
1115 torch::Tensor s = y.sum();
1116
1117 s.backward();
1118 ASSERT_EQ(y.ndimension(), 4);
1119 ASSERT_EQ(s.ndimension(), 0);
1120 ASSERT_EQ(y.sizes(), x.sizes());
1121 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
1122 }
1123 }
1124
TEST_F(ModulesTest,LayerNorm)1125 TEST_F(ModulesTest, LayerNorm) {
1126 LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5));
1127 auto x = torch::randn({2, 2}, torch::requires_grad());
1128 auto y = model(x);
1129 auto y_exp = torch::layer_norm(x, {2, 2}, model->weight, model->bias, 2e-5);
1130 torch::Tensor s = y.sum();
1131
1132 s.backward();
1133 ASSERT_EQ(y.ndimension(), 2);
1134 ASSERT_EQ(s.ndimension(), 0);
1135 for (const auto i : c10::irange(2)) {
1136 ASSERT_EQ(y.size(i), 2);
1137 }
1138
1139 ASSERT_EQ(model->weight.grad().numel(), 2 * 2);
1140 ASSERT_TRUE(torch::allclose(y, y_exp));
1141 }
1142
TEST_F(ModulesTest,GroupNorm)1143 TEST_F(ModulesTest, GroupNorm) {
1144 GroupNorm model(GroupNormOptions(2, 2).eps(2e-5));
1145 auto x = torch::randn({2, 2}, torch::requires_grad());
1146 auto y = model(x);
1147 auto y_exp = torch::group_norm(x, 2, model->weight, model->bias, 2e-5);
1148 torch::Tensor s = y.sum();
1149
1150 s.backward();
1151 ASSERT_EQ(y.ndimension(), 2);
1152 ASSERT_EQ(s.ndimension(), 0);
1153 for (const auto i : c10::irange(2)) {
1154 ASSERT_EQ(y.size(i), 2);
1155 }
1156
1157 ASSERT_EQ(model->weight.grad().numel(), 2);
1158 ASSERT_TRUE(torch::allclose(y, y_exp));
1159 }
1160
TEST_F(ModulesTest,Bilinear)1161 TEST_F(ModulesTest, Bilinear) {
1162 Bilinear model(5, 3, 2);
1163 auto x1 = torch::randn({10, 5}, torch::requires_grad());
1164 auto x2 = torch::randn({10, 3}, torch::requires_grad());
1165 auto y = model(x1, x2);
1166 torch::Tensor s = y.sum();
1167
1168 s.backward();
1169 ASSERT_EQ(y.ndimension(), 2);
1170 ASSERT_EQ(s.ndimension(), 0);
1171 ASSERT_EQ(y.size(0), 10);
1172 ASSERT_EQ(y.size(1), 2);
1173
1174 ASSERT_EQ(model->weight.grad().numel(), 2 * 5 * 3);
1175 }
1176
TEST_F(ModulesTest,Fold)1177 TEST_F(ModulesTest, Fold) {
1178 {
1179 Fold model(FoldOptions({3, 2}, {2, 2}));
1180 auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::requires_grad());
1181 auto output = model(input);
1182 auto expected = torch::tensor(
1183 {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1184 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1185 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
1186 torch::kFloat);
1187 auto s = output.sum();
1188 s.backward();
1189
1190 ASSERT_EQ(s.ndimension(), 0);
1191 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
1192 ASSERT_TRUE(output.allclose(expected));
1193 }
1194 {
1195 // input wrong dimension
1196 Fold model(FoldOptions({8, 8}, {3, 3}));
1197 ASSERT_THROWS_WITH(
1198 model(torch::randn({1, 3, 16, 16})),
1199 "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported (got 4D)");
1200 }
1201 }
1202
TEST_F(ModulesTest,Unfold)1203 TEST_F(ModulesTest, Unfold) {
1204 {
1205 Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2));
1206 auto input =
1207 torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3});
1208 auto output = model(input);
1209 auto expected = torch::tensor(
1210 {{{0.0, 0.0, 0.0, 6.0},
1211 {0.0, 0.0, 5.0, 7.0},
1212 {0.0, 3.0, 0.0, 0.0},
1213 {2.0, 4.0, 0.0, 0.0},
1214 {0.0, 0.0, 0.0, 12.0},
1215 {0.0, 0.0, 11.0, 13.0},
1216 {0.0, 9.0, 0.0, 0.0},
1217 {8.0, 10.0, 0.0, 0.0}}},
1218 torch::kFloat);
1219 auto s = output.sum();
1220 s.backward();
1221
1222 ASSERT_EQ(s.ndimension(), 0);
1223 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
1224 ASSERT_TRUE(output.allclose(expected));
1225 }
1226 {
1227 // input wrong dimension
1228 Unfold model(UnfoldOptions({2, 4}));
1229 ASSERT_THROWS_WITH(
1230 model(torch::randn({1, 5, 2})),
1231 "Input Error: Only 4D input Tensors are supported (got 3D)");
1232 }
1233 {
1234 // calculated output shape is too small
1235 Unfold model(UnfoldOptions({2, 3}));
1236 ASSERT_THROWS_WITH(
1237 model(torch::randn({1, 2, 2, 2})),
1238 "Given input with spatial size (2, 2), kernel_size=(2, 3), "
1239 "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
1240 "sliding blocks as (1, 0), but its components must be at least one.");
1241 }
1242 }
1243
TEST_F(ModulesTest,SimpleContainer)1244 TEST_F(ModulesTest, SimpleContainer) {
1245 auto model = std::make_shared<SimpleContainer>();
1246 auto l1 = model->add(Linear(10, 3), "l1");
1247 auto l2 = model->add(Linear(3, 5), "l2");
1248 auto l3 = model->add(Linear(5, 100), "l3");
1249
1250 auto x = torch::randn({1000, 10}, torch::requires_grad());
1251 x = l1(x).clamp_min(0);
1252 x = l2(x).clamp_min(0);
1253 x = l3(x).clamp_min(0);
1254
1255 x.backward(torch::ones_like(x));
1256 ASSERT_EQ(x.ndimension(), 2);
1257 ASSERT_EQ(x.size(0), 1000);
1258 ASSERT_EQ(x.size(1), 100);
1259 ASSERT_EQ(x.min().item<float>(), 0);
1260 }
1261
TEST_F(ModulesTest,EmbeddingBasic)1262 TEST_F(ModulesTest, EmbeddingBasic) {
1263 const int64_t dict_size = 10;
1264 Embedding model(dict_size, 2);
1265 ASSERT_TRUE(model->named_parameters().contains("weight"));
1266 ASSERT_EQ(model->weight.ndimension(), 2);
1267 ASSERT_EQ(model->weight.size(0), dict_size);
1268 ASSERT_EQ(model->weight.size(1), 2);
1269
1270 // Cannot get gradients to change indices (input) - only for embedding
1271 // params
1272 auto x = torch::full({10}, dict_size - 1, torch::kInt64);
1273 auto y = model(x);
1274 torch::Tensor s = y.sum();
1275
1276 s.backward();
1277 ASSERT_EQ(y.ndimension(), 2);
1278 ASSERT_EQ(s.ndimension(), 0);
1279 ASSERT_EQ(y.size(0), 10);
1280 ASSERT_EQ(y.size(1), 2);
1281
1282 ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size);
1283 }
1284
TEST_F(ModulesTest,EmbeddingList)1285 TEST_F(ModulesTest, EmbeddingList) {
1286 Embedding model(6, 4);
1287 auto x = torch::full({2, 3}, 5, torch::kInt64);
1288 auto y = model(x);
1289 torch::Tensor s = y.sum();
1290
1291 s.backward();
1292 ASSERT_EQ(y.ndimension(), 3);
1293 ASSERT_EQ(y.size(0), 2);
1294 ASSERT_EQ(y.size(1), 3);
1295 ASSERT_EQ(y.size(2), 4);
1296 }
1297
TEST_F(ModulesTest,EmbeddingFromPretrained)1298 TEST_F(ModulesTest, EmbeddingFromPretrained) {
1299 auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
1300 Embedding embedding = torch::nn::Embedding::from_pretrained(weight);
1301 auto input = torch::tensor({1}, torch::kLong);
1302 ASSERT_TRUE(torch::allclose(
1303 embedding(input), torch::tensor({4.0000, 5.1000, 6.3000})));
1304 }
1305
TEST_F(ModulesTest,EmbeddingBagFromPretrained)1306 TEST_F(ModulesTest, EmbeddingBagFromPretrained) {
1307 auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
1308 EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight);
1309 auto input = torch::zeros({{1, 2}}, torch::kLong);
1310 input[0] = torch::tensor({1, 0});
1311 ASSERT_TRUE(torch::allclose(
1312 embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500})));
1313 }
1314
TEST_F(ModulesTest,AlphaDropout)1315 TEST_F(ModulesTest, AlphaDropout) {
1316 AlphaDropout alpha_dropout(0.5);
1317 torch::Tensor x = torch::ones(100, torch::requires_grad());
1318 torch::Tensor y = alpha_dropout(x);
1319
1320 y.backward(torch::ones_like(y));
1321
1322 ASSERT_EQ(y.ndimension(), 1);
1323 ASSERT_EQ(y.size(0), 100);
1324 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1325 ASSERT_GT(y.sum().item<float>(), 40); // Probably
1326
1327 alpha_dropout->eval();
1328 y = alpha_dropout(x);
1329
1330 ASSERT_EQ(y.sum().item<float>(), 100);
1331 }
1332
TEST_F(ModulesTest,FeatureAlphaDropout)1333 TEST_F(ModulesTest, FeatureAlphaDropout) {
1334 FeatureAlphaDropout feature_alpha_dropout(0.5);
1335 torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
1336 torch::Tensor y = feature_alpha_dropout(x);
1337
1338 y.backward(torch::ones_like(y));
1339
1340 ASSERT_EQ(y.ndimension(), 2);
1341 ASSERT_EQ(y.size(0), 10);
1342 ASSERT_EQ(y.size(1), 10);
1343 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1344 ASSERT_GT(y.sum().item<float>(), 40); // Probably
1345
1346 feature_alpha_dropout->eval();
1347 y = feature_alpha_dropout(x);
1348
1349 ASSERT_EQ(y.sum().item<float>(), 100);
1350 }
1351
TEST_F(ModulesTest,Dropout)1352 TEST_F(ModulesTest, Dropout) {
1353 for (const auto inplace : {false, true}) {
1354 Dropout dropout(DropoutOptions(0.5).inplace(inplace));
1355 torch::Tensor x = torch::ones(100);
1356 if (!inplace) {
1357 x.requires_grad_(true);
1358 }
1359 torch::Tensor y = dropout(x);
1360
1361 ASSERT_EQ(y.ndimension(), 1);
1362 ASSERT_EQ(y.size(0), 100);
1363 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1364 ASSERT_GT(y.sum().item<float>(), 70); // Probably
1365 if (inplace) {
1366 ASSERT_TRUE(y.allclose(x));
1367 } else {
1368 y.backward(torch::ones_like(y));
1369 }
1370
1371 dropout->eval();
1372 y = dropout(torch::ones(100));
1373 ASSERT_EQ(y.sum().item<float>(), 100);
1374 }
1375 }
1376
TEST_F(ModulesTest,Dropout2d)1377 TEST_F(ModulesTest, Dropout2d) {
1378 auto p = 0.5;
1379 for (const auto inplace : {false, true}) {
1380 Dropout2d dropout(Dropout2dOptions(p).inplace(inplace));
1381 torch::Tensor x = torch::empty({50, 50, 2, 2}).fill_(1 - p);
1382 if (!inplace) {
1383 x.requires_grad_(true);
1384 }
1385 torch::Tensor y = dropout(x);
1386
1387 ASSERT_EQ(y.ndimension(), 4);
1388 ASSERT_EQ(y.size(0), 50);
1389 ASSERT_EQ(y.size(1), 50);
1390 ASSERT_EQ(y.size(2), 2);
1391 ASSERT_EQ(y.size(3), 2);
1392 ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1393
1394 if (inplace) {
1395 ASSERT_TRUE(y.allclose(x));
1396 } else {
1397 y.backward(torch::ones_like(y));
1398 }
1399
1400 dropout->eval();
1401 y = dropout(torch::ones({2, 2, 10, 10}));
1402 ASSERT_EQ(y.sum().item<float>(), 400);
1403 }
1404 }
1405
TEST_F(ModulesTest,Dropout3d)1406 TEST_F(ModulesTest, Dropout3d) {
1407 for (const auto inplace : {false, true}) {
1408 auto p = 0.5;
1409 Dropout3d dropout(Dropout3dOptions(p).inplace(inplace));
1410 torch::Tensor x = torch::empty({50, 50, 2, 2, 2}).fill_(1 - p);
1411 if (!inplace) {
1412 x.requires_grad_(true);
1413 }
1414 torch::Tensor y = dropout(x);
1415
1416 ASSERT_EQ(y.ndimension(), 5);
1417 ASSERT_EQ(y.size(0), 50);
1418 ASSERT_EQ(y.size(1), 50);
1419 ASSERT_EQ(y.size(2), 2);
1420 ASSERT_EQ(y.size(3), 2);
1421 ASSERT_EQ(y.size(4), 2);
1422 ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1423
1424 if (inplace) {
1425 ASSERT_TRUE(y.allclose(x));
1426 } else {
1427 y.backward(torch::ones_like(y));
1428 }
1429
1430 dropout->eval();
1431 y = dropout(torch::ones({4, 4, 5, 5}));
1432 ASSERT_EQ(y.sum().item<float>(), 400);
1433 }
1434 }
1435
TEST_F(ModulesTest,Parameters)1436 TEST_F(ModulesTest, Parameters) {
1437 auto model = std::make_shared<NestedModel>();
1438 auto parameters = model->named_parameters();
1439 ASSERT_EQ(parameters["param"].size(0), 3);
1440 ASSERT_EQ(parameters["param"].size(1), 2);
1441 ASSERT_EQ(parameters["param"].size(2), 21);
1442 ASSERT_EQ(parameters["l1.bias"].size(0), 20);
1443 ASSERT_EQ(parameters["l1.weight"].size(0), 20);
1444 ASSERT_EQ(parameters["l1.weight"].size(1), 5);
1445 ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
1446 ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
1447 ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
1448 ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
1449 ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
1450 ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
1451 ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
1452 ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
1453 ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
1454 }
1455
TEST_F(ModulesTest,FunctionalCallsSuppliedFunction)1456 TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
1457 bool was_called = false;
1458 auto functional = Functional([&was_called](torch::Tensor input) {
1459 was_called = true;
1460 return input;
1461 });
1462 auto output = functional(torch::ones(5, torch::requires_grad()));
1463 ASSERT_TRUE(was_called);
1464 ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1465
1466 was_called = false;
1467 // Use the call operator overload here.
1468 output = functional(torch::ones(5, torch::requires_grad()));
1469 ASSERT_TRUE(was_called);
1470 ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1471 }
1472
TEST_F(ModulesTest,FunctionalWithTorchFunction)1473 TEST_F(ModulesTest, FunctionalWithTorchFunction) {
1474 auto functional = Functional(torch::relu);
1475 ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1476 ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1477 ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
1478 }
1479
TEST_F(ModulesTest,FunctionalArgumentBinding)1480 TEST_F(ModulesTest, FunctionalArgumentBinding) {
1481 auto functional =
1482 Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
1483 ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
1484 }
1485
TEST_F(ModulesTest,BatchNorm1dStateful)1486 TEST_F(ModulesTest, BatchNorm1dStateful) {
1487 BatchNorm1d bn(5);
1488
1489 ASSERT_TRUE(bn->options.track_running_stats());
1490
1491 ASSERT_TRUE(bn->running_mean.defined());
1492 ASSERT_EQ(bn->running_mean.dim(), 1);
1493 ASSERT_EQ(bn->running_mean.size(0), 5);
1494
1495 ASSERT_TRUE(bn->running_var.defined());
1496 ASSERT_EQ(bn->running_var.dim(), 1);
1497 ASSERT_EQ(bn->running_var.size(0), 5);
1498
1499 ASSERT_TRUE(bn->num_batches_tracked.defined());
1500 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1501
1502 ASSERT_TRUE(bn->options.affine());
1503
1504 ASSERT_TRUE(bn->weight.defined());
1505 ASSERT_EQ(bn->weight.dim(), 1);
1506 ASSERT_EQ(bn->weight.size(0), 5);
1507
1508 ASSERT_TRUE(bn->bias.defined());
1509 ASSERT_EQ(bn->bias.dim(), 1);
1510 ASSERT_EQ(bn->bias.size(0), 5);
1511 }
1512
TEST_F(ModulesTest,BatchNorm1dStateless)1513 TEST_F(ModulesTest, BatchNorm1dStateless) {
1514 BatchNorm1d bn(
1515 BatchNorm1dOptions(5).track_running_stats(false).affine(false));
1516
1517 ASSERT_FALSE(bn->running_mean.defined());
1518 ASSERT_FALSE(bn->running_var.defined());
1519 ASSERT_FALSE(bn->num_batches_tracked.defined());
1520 ASSERT_FALSE(bn->weight.defined());
1521 ASSERT_FALSE(bn->bias.defined());
1522 }
1523
TEST_F(ModulesTest,BatchNorm1d)1524 TEST_F(ModulesTest, BatchNorm1d) {
1525 BatchNorm1d bn(5);
1526 bn->eval();
1527
1528 auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1529 auto output = bn->forward(input);
1530 auto expected = torch::tensor(
1531 {{{0.0000, 1.0000},
1532 {2.0000, 3.0000},
1533 {4.0000, 5.0000},
1534 {6.0000, 7.0000},
1535 {8.0000, 9.0000}},
1536 {{10.0000, 10.9999},
1537 {11.9999, 12.9999},
1538 {13.9999, 14.9999},
1539 {15.9999, 16.9999},
1540 {17.9999, 18.9999}}});
1541 ASSERT_TRUE(output.allclose(expected));
1542 auto s = output.sum();
1543 s.backward();
1544
1545 ASSERT_EQ(input.sizes(), input.grad().sizes());
1546 }
1547
TEST_F(ModulesTest,BatchNorm2dStateful)1548 TEST_F(ModulesTest, BatchNorm2dStateful) {
1549 BatchNorm2d bn(5);
1550
1551 ASSERT_TRUE(bn->options.track_running_stats());
1552
1553 ASSERT_TRUE(bn->running_mean.defined());
1554 ASSERT_EQ(bn->running_mean.dim(), 1);
1555 ASSERT_EQ(bn->running_mean.size(0), 5);
1556
1557 ASSERT_TRUE(bn->running_var.defined());
1558 ASSERT_EQ(bn->running_var.dim(), 1);
1559 ASSERT_EQ(bn->running_var.size(0), 5);
1560
1561 ASSERT_TRUE(bn->num_batches_tracked.defined());
1562 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1563
1564 ASSERT_TRUE(bn->options.affine());
1565
1566 ASSERT_TRUE(bn->weight.defined());
1567 ASSERT_EQ(bn->weight.dim(), 1);
1568 ASSERT_EQ(bn->weight.size(0), 5);
1569
1570 ASSERT_TRUE(bn->bias.defined());
1571 ASSERT_EQ(bn->bias.dim(), 1);
1572 ASSERT_EQ(bn->bias.size(0), 5);
1573 }
1574
TEST_F(ModulesTest,BatchNorm2dStateless)1575 TEST_F(ModulesTest, BatchNorm2dStateless) {
1576 BatchNorm2d bn(
1577 BatchNorm2dOptions(5).track_running_stats(false).affine(false));
1578
1579 ASSERT_FALSE(bn->running_mean.defined());
1580 ASSERT_FALSE(bn->running_var.defined());
1581 ASSERT_FALSE(bn->num_batches_tracked.defined());
1582 ASSERT_FALSE(bn->weight.defined());
1583 ASSERT_FALSE(bn->bias.defined());
1584 }
1585
TEST_F(ModulesTest,BatchNorm2d)1586 TEST_F(ModulesTest, BatchNorm2d) {
1587 BatchNorm2d bn(5);
1588 bn->eval();
1589
1590 auto input =
1591 torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1592 auto output = bn->forward(input);
1593 auto expected = torch::tensor(
1594 {{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1595 {{4.0000, 5.0000}, {6.0000, 7.0000}},
1596 {{8.0000, 9.0000}, {10.0000, 10.9999}},
1597 {{11.9999, 12.9999}, {13.9999, 14.9999}},
1598 {{15.9999, 16.9999}, {17.9999, 18.9999}}},
1599 {{{19.9999, 20.9999}, {21.9999, 22.9999}},
1600 {{23.9999, 24.9999}, {25.9999, 26.9999}},
1601 {{27.9999, 28.9999}, {29.9998, 30.9998}},
1602 {{31.9998, 32.9998}, {33.9998, 34.9998}},
1603 {{35.9998, 36.9998}, {37.9998, 38.9998}}}});
1604 ASSERT_TRUE(output.allclose(expected));
1605 auto s = output.sum();
1606 s.backward();
1607
1608 ASSERT_EQ(input.sizes(), input.grad().sizes());
1609 }
1610
TEST_F(ModulesTest,BatchNorm3dStateful)1611 TEST_F(ModulesTest, BatchNorm3dStateful) {
1612 BatchNorm3d bn(5);
1613
1614 ASSERT_TRUE(bn->options.track_running_stats());
1615
1616 ASSERT_TRUE(bn->running_mean.defined());
1617 ASSERT_EQ(bn->running_mean.dim(), 1);
1618 ASSERT_EQ(bn->running_mean.size(0), 5);
1619
1620 ASSERT_TRUE(bn->running_var.defined());
1621 ASSERT_EQ(bn->running_var.dim(), 1);
1622 ASSERT_EQ(bn->running_var.size(0), 5);
1623
1624 ASSERT_TRUE(bn->num_batches_tracked.defined());
1625 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1626
1627 ASSERT_TRUE(bn->options.affine());
1628
1629 ASSERT_TRUE(bn->weight.defined());
1630 ASSERT_EQ(bn->weight.dim(), 1);
1631 ASSERT_EQ(bn->weight.size(0), 5);
1632
1633 ASSERT_TRUE(bn->bias.defined());
1634 ASSERT_EQ(bn->bias.dim(), 1);
1635 ASSERT_EQ(bn->bias.size(0), 5);
1636 }
1637
TEST_F(ModulesTest,BatchNorm3dStateless)1638 TEST_F(ModulesTest, BatchNorm3dStateless) {
1639 BatchNorm3d bn(
1640 BatchNorm3dOptions(5).track_running_stats(false).affine(false));
1641
1642 ASSERT_FALSE(bn->running_mean.defined());
1643 ASSERT_FALSE(bn->running_var.defined());
1644 ASSERT_FALSE(bn->num_batches_tracked.defined());
1645 ASSERT_FALSE(bn->weight.defined());
1646 ASSERT_FALSE(bn->bias.defined());
1647 }
1648
TEST_F(ModulesTest,BatchNorm3d)1649 TEST_F(ModulesTest, BatchNorm3d) {
1650 BatchNorm3d bn(5);
1651 bn->eval();
1652
1653 auto input =
1654 torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1655 auto output = bn->forward(input);
1656 auto expected = torch::tensor(
1657 {{{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1658 {{4.0000, 5.0000}, {6.0000, 7.0000}}},
1659 {{{8.0000, 9.0000}, {10.0000, 10.9999}},
1660 {{11.9999, 12.9999}, {13.9999, 14.9999}}},
1661 {{{15.9999, 16.9999}, {17.9999, 18.9999}},
1662 {{19.9999, 20.9999}, {21.9999, 22.9999}}},
1663 {{{23.9999, 24.9999}, {25.9999, 26.9999}},
1664 {{27.9999, 28.9999}, {29.9998, 30.9998}}},
1665 {{{31.9998, 32.9998}, {33.9998, 34.9998}},
1666 {{35.9998, 36.9998}, {37.9998, 38.9998}}}},
1667 {{{{39.9998, 40.9998}, {41.9998, 42.9998}},
1668 {{43.9998, 44.9998}, {45.9998, 46.9998}}},
1669 {{{47.9998, 48.9998}, {49.9997, 50.9997}},
1670 {{51.9997, 52.9997}, {53.9997, 54.9997}}},
1671 {{{55.9997, 56.9997}, {57.9997, 58.9997}},
1672 {{59.9997, 60.9997}, {61.9997, 62.9997}}},
1673 {{{63.9997, 64.9997}, {65.9997, 66.9997}},
1674 {{67.9997, 68.9997}, {69.9996, 70.9996}}},
1675 {{{71.9996, 72.9996}, {73.9996, 74.9996}},
1676 {{75.9996, 76.9996}, {77.9996, 78.9996}}}}});
1677 ASSERT_TRUE(output.allclose(expected));
1678 auto s = output.sum();
1679 s.backward();
1680
1681 ASSERT_EQ(input.sizes(), input.grad().sizes());
1682 }
1683
TEST_F(ModulesTest,InstanceNorm1dStateful)1684 TEST_F(ModulesTest, InstanceNorm1dStateful) {
1685 InstanceNorm1d instance_norm(
1686 InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
1687
1688 ASSERT_TRUE(instance_norm->options.track_running_stats());
1689
1690 ASSERT_TRUE(instance_norm->running_mean.defined());
1691 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1692 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1693
1694 ASSERT_TRUE(instance_norm->running_var.defined());
1695 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1696 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1697
1698 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1699 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1700
1701 ASSERT_TRUE(instance_norm->options.affine());
1702
1703 ASSERT_TRUE(instance_norm->weight.defined());
1704 ASSERT_EQ(instance_norm->weight.dim(), 1);
1705 ASSERT_EQ(instance_norm->weight.size(0), 5);
1706
1707 ASSERT_TRUE(instance_norm->bias.defined());
1708 ASSERT_EQ(instance_norm->bias.dim(), 1);
1709 ASSERT_EQ(instance_norm->bias.size(0), 5);
1710 }
1711
TEST_F(ModulesTest,InstanceNorm1dStateless)1712 TEST_F(ModulesTest, InstanceNorm1dStateless) {
1713 InstanceNorm1d instance_norm(
1714 InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
1715
1716 ASSERT_FALSE(instance_norm->running_mean.defined());
1717 ASSERT_FALSE(instance_norm->running_var.defined());
1718 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1719 ASSERT_FALSE(instance_norm->weight.defined());
1720 ASSERT_FALSE(instance_norm->bias.defined());
1721 }
1722
TEST_F(ModulesTest,InstanceNorm1d)1723 TEST_F(ModulesTest, InstanceNorm1d) {
1724 InstanceNorm1d instance_norm(5);
1725 instance_norm->eval();
1726
1727 auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1728 auto output = instance_norm->forward(input);
1729 auto expected = torch::tensor(
1730 {{{-1.0000, 1.0000},
1731 {-1.0000, 1.0000},
1732 {-1.0000, 1.0000},
1733 {-1.0000, 1.0000},
1734 {-1.0000, 1.0000}},
1735 {{-1.0000, 1.0000},
1736 {-1.0000, 1.0000},
1737 {-1.0000, 1.0000},
1738 {-1.0000, 1.0000},
1739 {-1.0000, 1.0000}}});
1740 ASSERT_TRUE(output.allclose(expected, 1e-3));
1741 auto s = output.sum();
1742 s.backward();
1743
1744 ASSERT_EQ(input.sizes(), input.grad().sizes());
1745 }
1746
TEST_F(ModulesTest,InstanceNorm2dStateful)1747 TEST_F(ModulesTest, InstanceNorm2dStateful) {
1748 InstanceNorm2d instance_norm(
1749 InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
1750
1751 ASSERT_TRUE(instance_norm->options.track_running_stats());
1752
1753 ASSERT_TRUE(instance_norm->running_mean.defined());
1754 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1755 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1756
1757 ASSERT_TRUE(instance_norm->running_var.defined());
1758 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1759 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1760
1761 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1762 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1763
1764 ASSERT_TRUE(instance_norm->options.affine());
1765
1766 ASSERT_TRUE(instance_norm->weight.defined());
1767 ASSERT_EQ(instance_norm->weight.dim(), 1);
1768 ASSERT_EQ(instance_norm->weight.size(0), 5);
1769
1770 ASSERT_TRUE(instance_norm->bias.defined());
1771 ASSERT_EQ(instance_norm->bias.dim(), 1);
1772 ASSERT_EQ(instance_norm->bias.size(0), 5);
1773 }
1774
TEST_F(ModulesTest,InstanceNorm2dStateless)1775 TEST_F(ModulesTest, InstanceNorm2dStateless) {
1776 InstanceNorm2d instance_norm(
1777 InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
1778
1779 ASSERT_FALSE(instance_norm->running_mean.defined());
1780 ASSERT_FALSE(instance_norm->running_var.defined());
1781 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1782 ASSERT_FALSE(instance_norm->weight.defined());
1783 ASSERT_FALSE(instance_norm->bias.defined());
1784 }
1785
TEST_F(ModulesTest,InstanceNorm2d)1786 TEST_F(ModulesTest, InstanceNorm2d) {
1787 InstanceNorm2d instance_norm(5);
1788 instance_norm->eval();
1789
1790 auto input =
1791 torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1792 auto output = instance_norm->forward(input);
1793 auto expected = torch::tensor(
1794 {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1795 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1796 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1797 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1798 {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
1799 {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1800 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1801 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1802 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1803 {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});
1804 ASSERT_TRUE(output.allclose(expected, 1e-3));
1805 auto s = output.sum();
1806 s.backward();
1807
1808 ASSERT_EQ(input.sizes(), input.grad().sizes());
1809 }
1810
TEST_F(ModulesTest,InstanceNorm3dStateful)1811 TEST_F(ModulesTest, InstanceNorm3dStateful) {
1812 InstanceNorm3d instance_norm(
1813 InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
1814
1815 ASSERT_TRUE(instance_norm->options.track_running_stats());
1816
1817 ASSERT_TRUE(instance_norm->running_mean.defined());
1818 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1819 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1820
1821 ASSERT_TRUE(instance_norm->running_var.defined());
1822 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1823 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1824
1825 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1826 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1827
1828 ASSERT_TRUE(instance_norm->options.affine());
1829
1830 ASSERT_TRUE(instance_norm->weight.defined());
1831 ASSERT_EQ(instance_norm->weight.dim(), 1);
1832 ASSERT_EQ(instance_norm->weight.size(0), 5);
1833
1834 ASSERT_TRUE(instance_norm->bias.defined());
1835 ASSERT_EQ(instance_norm->bias.dim(), 1);
1836 ASSERT_EQ(instance_norm->bias.size(0), 5);
1837 }
1838
TEST_F(ModulesTest,InstanceNorm3dStateless)1839 TEST_F(ModulesTest, InstanceNorm3dStateless) {
1840 InstanceNorm3d instance_norm(
1841 InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
1842
1843 ASSERT_FALSE(instance_norm->running_mean.defined());
1844 ASSERT_FALSE(instance_norm->running_var.defined());
1845 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1846 ASSERT_FALSE(instance_norm->weight.defined());
1847 ASSERT_FALSE(instance_norm->bias.defined());
1848 }
1849
TEST_F(ModulesTest,InstanceNorm3d)1850 TEST_F(ModulesTest, InstanceNorm3d) {
1851 InstanceNorm3d instance_norm(5);
1852 instance_norm->eval();
1853
1854 auto input =
1855 torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1856 auto output = instance_norm->forward(input);
1857 auto expected = torch::tensor(
1858 {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1859 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1860 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1861 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1862 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1863 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1864 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1865 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1866 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1867 {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
1868 {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1869 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1870 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1871 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1872 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1873 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1874 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1875 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1876 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1877 {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
1878 ASSERT_TRUE(output.allclose(expected, 1e-3));
1879 auto s = output.sum();
1880 s.backward();
1881
1882 ASSERT_EQ(input.sizes(), input.grad().sizes());
1883 }
1884
TEST_F(ModulesTest,Linear_CUDA)1885 TEST_F(ModulesTest, Linear_CUDA) {
1886 Linear model(5, 2);
1887 model->to(torch::kCUDA);
1888 auto x =
1889 torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
1890 auto y = model(x);
1891 torch::Tensor s = y.sum();
1892
1893 s.backward();
1894 ASSERT_EQ(y.ndimension(), 2);
1895 ASSERT_EQ(s.ndimension(), 0);
1896 ASSERT_EQ(y.size(0), 10);
1897 ASSERT_EQ(y.size(1), 2);
1898
1899 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1900 }
1901
TEST_F(ModulesTest,Linear2_CUDA)1902 TEST_F(ModulesTest, Linear2_CUDA) {
1903 Linear model(5, 2);
1904 model->to(torch::kCUDA);
1905 model->to(torch::kCPU);
1906 auto x = torch::randn({10, 5}, torch::requires_grad());
1907 auto y = model(x);
1908 torch::Tensor s = y.sum();
1909
1910 s.backward();
1911 ASSERT_EQ(y.ndimension(), 2);
1912 ASSERT_EQ(s.ndimension(), 0);
1913 ASSERT_EQ(y.size(0), 10);
1914 ASSERT_EQ(y.size(1), 2);
1915
1916 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1917 }
1918
TEST_F(ModulesTest,L1Loss)1919 TEST_F(ModulesTest, L1Loss) {
1920 L1Loss loss;
1921 auto input = torch::randn({5, 6}, torch::requires_grad());
1922 auto target = torch::empty({5, 6}).random_(2);
1923 auto output = loss->forward(torch::sigmoid(input), target);
1924 auto s = output.sum();
1925 s.backward();
1926
1927 ASSERT_EQ(output.sizes(), std::vector<int64_t>());
1928 ASSERT_EQ(input.sizes(), input.grad().sizes());
1929 }
1930
TEST_F(ModulesTest,MSELoss)1931 TEST_F(ModulesTest, MSELoss) {
1932 MSELoss loss;
1933 auto input = torch::randn({5, 6}, torch::requires_grad());
1934 auto target = torch::empty({5, 6}).random_(2);
1935 auto output = loss->forward(torch::sigmoid(input), target);
1936 auto s = output.sum();
1937 s.backward();
1938
1939 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1940 ASSERT_EQ(input.sizes(), input.grad().sizes());
1941 }
1942
TEST_F(ModulesTest,BCELoss)1943 TEST_F(ModulesTest, BCELoss) {
1944 BCELoss loss;
1945 auto input = torch::randn({5, 6}, torch::requires_grad());
1946 auto target = torch::empty({5, 6}).random_(2);
1947 auto output = loss->forward(torch::sigmoid(input), target);
1948 auto s = output.sum();
1949 s.backward();
1950
1951 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1952 ASSERT_EQ(input.sizes(), input.grad().sizes());
1953 }
1954
TEST_F(ModulesTest,KLDivLoss)1955 TEST_F(ModulesTest, KLDivLoss) {
1956 KLDivLoss loss;
1957 auto input = torch::randn({5, 6}, torch::requires_grad());
1958 auto target = torch::empty({5, 6}).random_(2);
1959 auto output = loss->forward(torch::sigmoid(input), target);
1960 auto s = output.sum();
1961 s.backward();
1962
1963 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1964 ASSERT_EQ(input.sizes(), input.grad().sizes());
1965 }
1966
TEST_F(ModulesTest,HingeEmbeddingLoss)1967 TEST_F(ModulesTest, HingeEmbeddingLoss) {
1968 HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2));
1969 auto input = torch::tensor(
1970 {{2, 22, 4}, {20, 10, 0}},
1971 torch::dtype(torch::kFloat).requires_grad(true));
1972 auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
1973 auto output = loss->forward(input, target);
1974 auto expected = torch::tensor({10}, torch::kFloat);
1975 auto s = output.sum();
1976 s.backward();
1977
1978 ASSERT_TRUE(output.allclose(expected));
1979 ASSERT_EQ(input.sizes(), input.grad().sizes());
1980 }
1981
TEST_F(ModulesTest,MultiMarginLoss)1982 TEST_F(ModulesTest, MultiMarginLoss) {
1983 auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
1984 MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight));
1985 auto input = torch::tensor(
1986 {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
1987 torch::dtype(torch::kFloat).requires_grad(true));
1988 auto target = torch::tensor({2, 1, 0}, torch::kLong);
1989 auto output = loss->forward(input, target);
1990 auto expected = torch::tensor({0.305556}, torch::kFloat);
1991 auto s = output.sum();
1992 s.backward();
1993
1994 ASSERT_TRUE(output.allclose(expected, 1e-04));
1995 ASSERT_EQ(input.sizes(), input.grad().sizes());
1996 }
1997
TEST_F(ModulesTest,CosineEmbeddingLoss)1998 TEST_F(ModulesTest, CosineEmbeddingLoss) {
1999 CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5));
2000 auto input1 = torch::tensor(
2001 {{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
2002 auto input2 = torch::tensor(
2003 {{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true));
2004 auto target = torch::tensor({1, -1});
2005 auto output = cos(input1, input2, target);
2006 auto expected = torch::tensor({0.1004}, torch::kFloat);
2007 auto s = output.sum();
2008 s.backward();
2009
2010 ASSERT_TRUE(output.allclose(expected, 1e-4));
2011 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2012 ASSERT_EQ(input2.sizes(), input2.grad().sizes());
2013 }
2014
TEST_F(ModulesTest,SmoothL1LossDefaultOptions)2015 TEST_F(ModulesTest, SmoothL1LossDefaultOptions) {
2016 SmoothL1Loss loss;
2017 auto input = torch::tensor(
2018 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2019 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2020 auto output = loss(input, target);
2021 auto expected = torch::tensor(0.0233335, torch::kFloat);
2022 auto s = output.sum();
2023 s.backward();
2024
2025 ASSERT_TRUE(output.allclose(expected));
2026 ASSERT_EQ(input.sizes(), input.grad().sizes());
2027 }
2028
TEST_F(ModulesTest,HuberLossDefaultOptions)2029 TEST_F(ModulesTest, HuberLossDefaultOptions) {
2030 HuberLoss loss;
2031 auto input = torch::tensor(
2032 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2033 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2034 auto output = loss(input, target);
2035 auto expected = torch::tensor(0.0233335, torch::kFloat);
2036 auto s = output.sum();
2037 s.backward();
2038
2039 ASSERT_TRUE(output.allclose(expected));
2040 ASSERT_EQ(input.sizes(), input.grad().sizes());
2041 }
2042
TEST_F(ModulesTest,MultiLabelMarginLossDefaultOptions)2043 TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) {
2044 MultiLabelMarginLoss loss;
2045 auto input = torch::tensor(
2046 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2047 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2048 auto output = loss->forward(input, target);
2049 auto expected = torch::tensor({0.8500}, torch::kFloat);
2050 auto s = output.sum();
2051 s.backward();
2052
2053 ASSERT_TRUE(output.allclose(expected));
2054 ASSERT_EQ(input.sizes(), input.grad().sizes());
2055 }
2056
TEST_F(ModulesTest,SmoothL1LossNoReduction)2057 TEST_F(ModulesTest, SmoothL1LossNoReduction) {
2058 SmoothL1Loss loss(/*reduction=*/torch::kNone);
2059 auto input = torch::tensor(
2060 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2061 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2062 auto output = loss(input, target);
2063 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2064 auto s = output.sum();
2065 s.backward();
2066
2067 ASSERT_TRUE(output.allclose(expected));
2068 ASSERT_EQ(input.sizes(), input.grad().sizes());
2069 }
2070
TEST_F(ModulesTest,HuberLossNoReduction)2071 TEST_F(ModulesTest, HuberLossNoReduction) {
2072 HuberLoss loss(/*reduction=*/torch::kNone);
2073 auto input = torch::tensor(
2074 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2075 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2076 auto output = loss(input, target);
2077 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2078 auto s = output.sum();
2079 s.backward();
2080
2081 ASSERT_TRUE(output.allclose(expected));
2082 ASSERT_EQ(input.sizes(), input.grad().sizes());
2083 }
2084
TEST_F(ModulesTest,MultiLabelMarginLossNoReduction)2085 TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) {
2086 MultiLabelMarginLoss loss(torch::kNone);
2087 auto input = torch::tensor(
2088 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2089 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2090 auto output = loss->forward(input, target);
2091 auto expected = torch::tensor({0.8500}, torch::kFloat);
2092 auto s = output.sum();
2093 s.backward();
2094
2095 ASSERT_TRUE(output.allclose(expected));
2096 ASSERT_EQ(input.sizes(), input.grad().sizes());
2097 }
2098
TEST_F(ModulesTest,SmoothL1LossBeta)2099 TEST_F(ModulesTest, SmoothL1LossBeta) {
2100 auto options = SmoothL1LossOptions().beta(0.2);
2101 SmoothL1Loss loss(options);
2102 auto input = torch::tensor(
2103 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2104 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2105 auto output = loss(input, target);
2106 auto expected = torch::tensor(0.108333, torch::kFloat);
2107 auto s = output.sum();
2108 s.backward();
2109
2110 ASSERT_TRUE(output.allclose(expected));
2111 ASSERT_EQ(input.sizes(), input.grad().sizes());
2112 }
2113
TEST_F(ModulesTest,HuberLossDelta)2114 TEST_F(ModulesTest, HuberLossDelta) {
2115 auto options = HuberLossOptions().delta(0.2);
2116 HuberLoss loss(options);
2117 auto input = torch::tensor(
2118 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2119 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2120 auto output = loss(input, target);
2121 auto expected = torch::tensor(0.0216666, torch::kFloat);
2122 auto s = output.sum();
2123 s.backward();
2124
2125 ASSERT_TRUE(output.allclose(expected));
2126 ASSERT_EQ(input.sizes(), input.grad().sizes());
2127 }
2128
TEST_F(ModulesTest,TripletMarginLoss)2129 TEST_F(ModulesTest, TripletMarginLoss) {
2130 TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0));
2131 auto anchor = torch::tensor(
2132 {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true));
2133 auto positive = torch::tensor(
2134 {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2135 auto negative = torch::tensor(
2136 {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true));
2137 auto output = loss->forward(anchor, positive, negative);
2138 auto expected = torch::tensor({0.}, torch::kFloat);
2139 auto s = output.sum();
2140 s.backward();
2141
2142 ASSERT_TRUE(output.allclose(expected, 1e-04));
2143 ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2144 }
2145
TEST_F(ModulesTest,TripletMarginWithDistanceLossDefaultParity)2146 TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
2147 // Check that if we use torch::pairwise_distance with the default
2148 // TripletMarginLoss options as our distance function, the outputs
2149 // are equal (i.e., equal under defaults).
2150
2151 std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2152 torch::kSum, torch::kMean, torch::kNone};
2153 std::vector<float> margins = {0.5, 1.0, 1.5};
2154 std::vector<bool> swaps = {true, false};
2155
2156 for (auto& reduction : reductions) {
2157 for (auto& margin : margins) {
2158 for (const auto swap : swaps) {
2159 auto anchor = torch::randn(
2160 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2161 auto positive = torch::randn(
2162 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2163 auto negative = torch::randn(
2164 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2165
2166 auto basicOptions =
2167 TripletMarginLossOptions().reduction(reduction).margin(margin).swap(
2168 swap);
2169 auto distanceOptions = TripletMarginWithDistanceLossOptions()
2170 .reduction(reduction)
2171 .margin(margin)
2172 .swap(swap);
2173 TripletMarginLoss basicLoss(basicOptions);
2174 TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
2175
2176 auto basicOutput = basicLoss->forward(anchor, positive, negative);
2177 auto distanceOutput = distanceLoss->forward(anchor, positive, negative);
2178 auto basicOperatorOutput = basicLoss(anchor, positive, negative);
2179 auto distanceOperatorOutput = distanceLoss(anchor, positive, negative);
2180
2181 ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
2182 ASSERT_TRUE(
2183 distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6));
2184 ASSERT_TRUE(
2185 distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6));
2186
2187 // handle for torch::kNone reduction
2188 auto sum = distanceOutput.sum();
2189 sum.backward();
2190 ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2191 ASSERT_EQ(positive.sizes(), positive.grad().sizes());
2192 ASSERT_EQ(negative.sizes(), negative.grad().sizes());
2193 }
2194 }
2195 }
2196 }
2197
TEST_F(ModulesTest,TripletMarginWithDistanceLossFunctionalParity)2198 TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
2199 // Check for parity between F::triplet_margin_with_distance_loss and
2200 // TripletMarginWithDistanceLoss.
2201 auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2202 return torch::pairwise_distance(x, y);
2203 };
2204 auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2205 return 1.0 - torch::cosine_similarity(x, y);
2206 };
2207 std::vector<TripletMarginWithDistanceLossOptions::distance_function_t>
2208 distance_functions = {pairwise_distance, cosine_distance};
2209
2210 std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2211 torch::kSum, torch::kMean, torch::kNone};
2212 std::vector<float> margins = {0.5, 1.0, 1.5};
2213 std::vector<bool> swaps = {true, false};
2214
2215 for (auto& function : distance_functions) {
2216 for (auto& reduction : reductions) {
2217 for (auto& margin : margins) {
2218 for (const auto swap : swaps) {
2219 auto moduleOptions = TripletMarginWithDistanceLossOptions()
2220 .distance_function(function)
2221 .reduction(reduction)
2222 .margin(margin)
2223 .swap(swap);
2224 auto functionOptions =
2225 torch::nn::functional::TripletMarginWithDistanceLossFuncOptions()
2226 .distance_function(function)
2227 .reduction(reduction)
2228 .margin(margin)
2229 .swap(swap);
2230
2231 auto anchor = torch::randn(
2232 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2233 auto positive = torch::randn(
2234 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2235 auto negative = torch::randn(
2236 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2237
2238 TripletMarginWithDistanceLoss distanceLoss(moduleOptions);
2239
2240 auto moduleOutput = distanceLoss->forward(anchor, positive, negative);
2241 auto moduleOperatorOutput = distanceLoss(anchor, positive, negative);
2242 auto functionOutput =
2243 torch::nn::functional::triplet_margin_with_distance_loss(
2244 anchor, positive, negative, functionOptions);
2245
2246 ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6));
2247 ASSERT_TRUE(
2248 moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6));
2249 }
2250 }
2251 }
2252 }
2253 }
2254
TEST_F(ModulesTest,NLLLoss)2255 TEST_F(ModulesTest, NLLLoss) {
2256 NLLLoss loss;
2257 auto input = torch::tensor(
2258 {{-0.1315, -3.1315, -2.5315},
2259 {-3.7038, -0.1038, -2.6038},
2260 {-2.3422, -1.3422, -0.4422}},
2261 torch::dtype(torch::kFloat).requires_grad(true));
2262 auto target = torch::tensor({1, 0, 2}, torch::kLong);
2263 auto output = loss->forward(input, target);
2264 auto expected = torch::tensor(2.4258, torch::kFloat);
2265 auto s = output.sum();
2266 s.backward();
2267
2268 ASSERT_TRUE(output.allclose(expected, 1e-04));
2269 ASSERT_TRUE(
2270 NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean))
2271 ->forward(input, target)
2272 .allclose(expected, 1e-04));
2273 }
2274
TEST_F(ModulesTest,CrossEntropyLoss)2275 TEST_F(ModulesTest, CrossEntropyLoss) {
2276 CrossEntropyLoss loss;
2277 auto input = torch::tensor(
2278 {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2279 auto target = torch::tensor({0, 1}, torch::kLong);
2280 auto output = loss->forward(input, target);
2281 auto expected = torch::tensor(0.6931, torch::kFloat);
2282 auto s = output.sum();
2283 s.backward();
2284
2285 ASSERT_TRUE(output.allclose(expected, 1e-04));
2286 ASSERT_EQ(input.sizes(), input.grad().sizes());
2287 ASSERT_TRUE(
2288 CrossEntropyLoss(
2289 CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
2290 ->forward(input, target)
2291 .allclose(expected, 1e-04));
2292
2293 // label smoothing with class indices
2294 loss = CrossEntropyLoss(
2295 CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
2296 input = torch::tensor(
2297 {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2298 target = torch::tensor({0, 1}, torch::kLong);
2299 output = loss->forward(input, target);
2300 expected = torch::tensor(0.3326, torch::kFloat);
2301 s = output.sum();
2302 s.backward();
2303
2304 ASSERT_TRUE(output.allclose(expected, 1e-04));
2305 ASSERT_EQ(input.sizes(), input.grad().sizes());
2306
2307 // label smoothing with with target probabilities
2308 loss = CrossEntropyLoss(
2309 CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
2310 input = torch::tensor(
2311 {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2312 target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
2313 output = loss->forward(input, target);
2314 expected = torch::tensor(0.5701, torch::kFloat);
2315 s = output.sum();
2316 s.backward();
2317
2318 ASSERT_TRUE(output.allclose(expected, 1e-04));
2319 ASSERT_EQ(input.sizes(), input.grad().sizes());
2320 }
2321
TEST_F(ModulesTest,CosineSimilarity)2322 TEST_F(ModulesTest, CosineSimilarity) {
2323 CosineSimilarity cos(CosineSimilarityOptions().dim(1));
2324 auto input1 = torch::tensor(
2325 {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2326 auto input2 = torch::tensor(
2327 {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2328 auto output = cos->forward(input1, input2);
2329 auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
2330 auto s = output.sum();
2331 s.backward();
2332
2333 ASSERT_TRUE(output.allclose(expected, 1e-04));
2334 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2335 }
2336
TEST_F(ModulesTest,SoftMarginLossDefaultOptions)2337 TEST_F(ModulesTest, SoftMarginLossDefaultOptions) {
2338 SoftMarginLoss loss;
2339 auto input = torch::tensor(
2340 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2341 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2342 auto output = loss->forward(input, target);
2343 auto expected = torch::tensor({1.3767317}, torch::kFloat);
2344 auto s = output.sum();
2345 s.backward();
2346
2347 ASSERT_TRUE(output.allclose(expected));
2348 ASSERT_EQ(input.sizes(), input.grad().sizes());
2349 }
2350
TEST_F(ModulesTest,MultiLabelSoftMarginLossDefaultOptions)2351 TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
2352 MultiLabelSoftMarginLoss loss;
2353 auto input = torch::tensor(
2354 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2355 torch::dtype(torch::kFloat).requires_grad(true));
2356 auto target =
2357 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2358 auto output = loss->forward(input, target);
2359 auto expected = torch::tensor({0.7608436}, torch::kFloat);
2360 auto s = output.sum();
2361 s.backward();
2362
2363 ASSERT_TRUE(output.allclose(expected));
2364 ASSERT_EQ(input.sizes(), input.grad().sizes());
2365 }
2366
TEST_F(ModulesTest,SoftMarginLossNoReduction)2367 TEST_F(ModulesTest, SoftMarginLossNoReduction) {
2368 SoftMarginLoss loss(torch::kNone);
2369 auto input = torch::tensor(
2370 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2371 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2372 auto output = loss->forward(input, target);
2373 auto expected = torch::tensor(
2374 {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
2375 auto s = output.sum();
2376 s.backward();
2377
2378 ASSERT_TRUE(output.allclose(expected));
2379 ASSERT_EQ(input.sizes(), input.grad().sizes());
2380 }
2381
TEST_F(ModulesTest,MultiLabelSoftMarginLossWeightedNoReduction)2382 TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
2383 auto input = torch::tensor(
2384 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2385 torch::dtype(torch::kFloat).requires_grad(true));
2386 auto target =
2387 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2388 auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
2389 auto options =
2390 MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight);
2391 MultiLabelSoftMarginLoss loss = MultiLabelSoftMarginLoss(options);
2392 auto output = loss->forward(input, target);
2393 auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
2394 auto s = output.sum();
2395 s.backward();
2396
2397 ASSERT_TRUE(output.allclose(expected));
2398 ASSERT_EQ(input.sizes(), input.grad().sizes());
2399 }
2400
TEST_F(ModulesTest,PairwiseDistance)2401 TEST_F(ModulesTest, PairwiseDistance) {
2402 PairwiseDistance dist(PairwiseDistanceOptions().p(1));
2403 auto input1 = torch::tensor(
2404 {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2405 auto input2 = torch::tensor(
2406 {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2407 auto output = dist->forward(input1, input2);
2408 auto expected = torch::tensor({6, 6}, torch::kFloat);
2409 auto s = output.sum();
2410 s.backward();
2411
2412 ASSERT_TRUE(output.allclose(expected));
2413 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2414 }
2415
TEST_F(ModulesTest,ELU)2416 TEST_F(ModulesTest, ELU) {
2417 const auto size = 3;
2418 for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2419 for (const auto inplace : {false, true}) {
2420 ELU model{ELUOptions().alpha(alpha).inplace(inplace)};
2421 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2422 x.resize_({size, size, size});
2423 if (!inplace) {
2424 x.requires_grad_(true);
2425 }
2426 auto x_orig = x.clone();
2427 auto y = model(x);
2428 torch::Tensor s = y.sum();
2429
2430 ASSERT_EQ(s.ndimension(), 0);
2431
2432 ASSERT_EQ(y.ndimension(), 3);
2433 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2434 auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2435 torch::min(torch::zeros_like(x_orig),
2436 alpha * (torch::exp(x_orig) - 1.0));
2437 ASSERT_TRUE(torch::allclose(y, y_exp));
2438 if (inplace) {
2439 ASSERT_TRUE(torch::allclose(x, y_exp));
2440 } else {
2441 s.backward();
2442 }
2443 }
2444 }
2445 }
2446
TEST_F(ModulesTest,SELU)2447 TEST_F(ModulesTest, SELU) {
2448 for (const auto inplace : {false, true}) {
2449 SELU model(inplace);
2450 auto input = torch::randn({5, 5});
2451 if (!inplace) {
2452 input.requires_grad_(true);
2453 }
2454 auto input_orig = input.clone();
2455 auto output = model->forward(input);
2456 const double scale = 1.0507009873554804934193349852946;
2457 const double alpha = 1.6732632423543772848170429916717;
2458 auto zero = torch::zeros_like(input);
2459 auto expected = scale *
2460 (torch::max(zero, input_orig) +
2461 torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
2462 auto s = output.sum();
2463
2464 ASSERT_EQ(s.ndimension(), 0);
2465 ASSERT_TRUE(output.allclose(expected));
2466 if (inplace) {
2467 ASSERT_TRUE(input.allclose(expected));
2468 } else {
2469 s.backward();
2470 }
2471 }
2472 }
2473
TEST_F(ModulesTest,Hardshrink)2474 TEST_F(ModulesTest, Hardshrink) {
2475 const auto size = 3;
2476 for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
2477 Hardshrink model{HardshrinkOptions().lambda(lambda)};
2478 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2479 x.resize_({size, size, size}).set_requires_grad(true);
2480 auto y = model(x);
2481 torch::Tensor s = y.sum();
2482
2483 s.backward();
2484 ASSERT_EQ(s.ndimension(), 0);
2485 ASSERT_EQ(y.ndimension(), 3);
2486 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2487 auto y_exp = (x.abs() > lambda) * x;
2488 ASSERT_TRUE(torch::allclose(y, y_exp));
2489 }
2490 }
2491
TEST_F(ModulesTest,Hardtanh)2492 TEST_F(ModulesTest, Hardtanh) {
2493 const auto size = 3;
2494 for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
2495 for (const auto max_val : {0.42, 1.0, 4.2}) {
2496 for (const auto inplace : {false, true}) {
2497 Hardtanh model{
2498 HardtanhOptions().min_val(min_val).max_val(max_val).inplace(
2499 inplace)};
2500 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2501 x.resize_({size, size, size});
2502 if (!inplace) {
2503 x.requires_grad_(true);
2504 }
2505 auto x_orig = x.clone();
2506 auto y = model(x);
2507 torch::Tensor s = y.sum();
2508
2509 ASSERT_EQ(s.ndimension(), 0);
2510 ASSERT_EQ(y.ndimension(), 3);
2511 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2512 auto y_exp = (x_orig < min_val) * min_val +
2513 ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig +
2514 (x_orig > max_val) * max_val;
2515 ASSERT_TRUE(torch::allclose(y, y_exp));
2516 if (inplace) {
2517 ASSERT_TRUE(torch::allclose(x, y_exp));
2518 } else {
2519 s.backward();
2520 }
2521 }
2522 }
2523 }
2524 }
2525
TEST_F(ModulesTest,HardtanhMinValGEMaxVal)2526 TEST_F(ModulesTest, HardtanhMinValGEMaxVal) {
2527 ASSERT_THROWS_WITH(
2528 Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)},
2529 "max_val must be greater than min_val");
2530 ASSERT_THROWS_WITH(
2531 Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)},
2532 "max_val must be greater than min_val");
2533
2534 Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)};
2535 ht->options.min_val(0.42);
2536 ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2537 ht->options.max_val(-0.42);
2538 ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2539 }
2540
TEST_F(ModulesTest,LeakyReLU)2541 TEST_F(ModulesTest, LeakyReLU) {
2542 const auto size = 3;
2543 for (const auto inplace : {false, true}) {
2544 for (const auto negative_slope : {0.0, 0.42, 1.0}) {
2545 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2546 LeakyReLU model{
2547 LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
2548 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2549 x.resize_({size, size, size});
2550 if (!inplace) {
2551 x.requires_grad_(true);
2552 }
2553 auto x_orig = x.clone();
2554 auto y = model(x);
2555 torch::Tensor s = y.sum();
2556
2557 ASSERT_EQ(s.ndimension(), 0);
2558 ASSERT_EQ(y.ndimension(), 3);
2559 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2560 auto y_exp =
2561 (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
2562 ASSERT_TRUE(torch::allclose(y, y_exp));
2563 if (inplace) {
2564 ASSERT_TRUE(torch::allclose(x, y_exp));
2565 } else {
2566 s.backward();
2567 }
2568 }
2569 }
2570 }
2571 }
2572
TEST_F(ModulesTest,LogSigmoid)2573 TEST_F(ModulesTest, LogSigmoid) {
2574 const auto size = 3;
2575 LogSigmoid model;
2576 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2577 x.resize_({size, size, size}).set_requires_grad(true);
2578 auto y = model(x);
2579 torch::Tensor s = y.sum();
2580
2581 s.backward();
2582 ASSERT_EQ(s.ndimension(), 0);
2583
2584 ASSERT_EQ(y.ndimension(), 3);
2585 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2586 auto y_exp = torch::log(
2587 torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
2588 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
2589 }
2590
TEST_F(ModulesTest,Softmax)2591 TEST_F(ModulesTest, Softmax) {
2592 Softmax m(/*dim=*/1);
2593 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2594 auto output = m(input);
2595 auto sum = torch::sum(torch::exp(input), 1);
2596
2597 for (const auto i : c10::irange(2)) {
2598 auto expected = torch::exp(input[i]) / sum[i];
2599 ASSERT_TRUE(torch::allclose(output[i], expected));
2600 }
2601 }
2602
TEST_F(ModulesTest,Softmin)2603 TEST_F(ModulesTest, Softmin) {
2604 Softmin m(/*dim=*/1);
2605 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2606 auto output = m(input);
2607 auto sum = torch::sum(torch::exp(-input), 1);
2608
2609 for (const auto i : c10::irange(2)) {
2610 auto expected = torch::exp(-input[i]) / sum[i];
2611 ASSERT_TRUE(torch::allclose(output[i], expected));
2612 }
2613 }
2614
TEST_F(ModulesTest,LogSoftmax)2615 TEST_F(ModulesTest, LogSoftmax) {
2616 LogSoftmax m(/*dim=*/1);
2617 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2618 auto output = m(input);
2619 auto sum = torch::sum(torch::exp(input), 1);
2620
2621 for (const auto i : c10::irange(2)) {
2622 auto expected = torch::log(torch::exp(input[i]) / sum[i]);
2623 ASSERT_TRUE(torch::allclose(output[i], expected));
2624 }
2625 }
2626
TEST_F(ModulesTest,AdaptiveLogSoftmaxWithLoss)2627 TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
2628 {
2629 // log_probs actually returns log_proba
2630 AdaptiveLogSoftmaxWithLoss asfm(
2631 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2632 auto x = torch::randn({4, 8});
2633 auto logprob_out = asfm->log_prob(x);
2634 ASSERT_TRUE(
2635 torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
2636 }
2637 {
2638 // test predict
2639 AdaptiveLogSoftmaxWithLoss asfm(
2640 AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
2641 .div_value(2.)
2642 .head_bias(true));
2643 auto x = torch::randn({64, 8});
2644 auto logprob_out = asfm->log_prob(x);
2645 auto predict_out = asfm->predict(x);
2646 ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
2647 }
2648 {
2649 // cluster sizes
2650 AdaptiveLogSoftmaxWithLoss asfm(
2651 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2652 auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
2653 auto y = torch::tensor({0, 17}, torch::kLong);
2654 auto asm_out = asfm(x, y);
2655 ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
2656 }
2657 {
2658 // forward returns the same thing as log_probs
2659 AdaptiveLogSoftmaxWithLoss asfm(
2660 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2661 auto x = torch::randn({4, 8});
2662 auto logprob_out = asfm->log_prob(x);
2663 NLLLoss nll_loss;
2664
2665 for (const auto v : c10::irange(4)) {
2666 auto y = torch::full({4}, v, torch::kLong);
2667 auto asm_out = asfm(x, y);
2668 auto out = asm_out.output;
2669 auto loss = torch::tensor(asm_out.loss, torch::kFloat);
2670 auto expected = nll_loss->forward(logprob_out, y);
2671
2672 ASSERT_TRUE(torch::allclose(loss, expected));
2673 ASSERT_TRUE(torch::allclose(
2674 out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
2675 }
2676 }
2677 {
2678 // test no batch dim
2679 AdaptiveLogSoftmaxWithLoss asfm(
2680 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2681 auto x = torch::randn({1, 16});
2682 auto y = torch::tensor({17});
2683 auto x2 = x.squeeze(0);
2684 auto y2 = y.squeeze(0);
2685 ASSERT_TRUE(
2686 torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
2687 }
2688 {
2689 // test div_value
2690 auto options =
2691 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
2692 ASSERT_THROWS_WITH(
2693 AdaptiveLogSoftmaxWithLoss(options),
2694 "div_value should not be equal to 0");
2695
2696 options =
2697 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
2698 ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
2699 }
2700 }
2701
TEST_F(ModulesTest,Softmax2d)2702 TEST_F(ModulesTest, Softmax2d) {
2703 Softmax2d m;
2704 auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
2705 auto output = m(input);
2706 auto sum = torch::sum(torch::exp(input), 1);
2707
2708 for (const auto i : c10::irange(1)) {
2709 for (const auto j : c10::irange(2)) {
2710 for (const auto k : c10::irange(3)) {
2711 for (const auto l : c10::irange(4)) {
2712 auto expected = torch::exp(input[i][j][k][l]) / sum[i][k][l];
2713 ASSERT_TRUE(torch::allclose(output[i][j][k][l], expected));
2714 }
2715 }
2716 }
2717 }
2718 }
2719
TEST_F(ModulesTest,PReLU)2720 TEST_F(ModulesTest, PReLU) {
2721 const auto num_parameters = 42;
2722 const auto init = 0.42;
2723
2724 PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)};
2725
2726 ASSERT_EQ(model->weight.sizes(), std::vector<int64_t>({num_parameters}));
2727 ASSERT_TRUE(
2728 torch::allclose(model->weight, torch::full(num_parameters, init)));
2729
2730 const auto x = torch::rand({100, num_parameters}) * 200 - 100;
2731 const auto y = model(x);
2732 const auto s = y.sum();
2733
2734 s.backward();
2735 ASSERT_EQ(s.ndimension(), 0);
2736
2737 ASSERT_EQ(y.ndimension(), x.ndimension());
2738 ASSERT_EQ(y.sizes(), x.sizes());
2739 const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x;
2740 ASSERT_TRUE(torch::allclose(y, y_exp));
2741 }
2742
TEST_F(ModulesTest,ReLU)2743 TEST_F(ModulesTest, ReLU) {
2744 for (const auto inplace : {false, true}) {
2745 const auto size = 3;
2746 ReLU model(inplace);
2747 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2748 x.resize_({size, size, size});
2749 if (!inplace) {
2750 x.requires_grad_(true);
2751 }
2752 auto x_orig = x.clone();
2753 auto y = model(x);
2754 torch::Tensor s = y.sum();
2755
2756 ASSERT_EQ(s.ndimension(), 0);
2757 ASSERT_EQ(y.ndimension(), 3);
2758 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2759 auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig;
2760 ASSERT_TRUE(torch::allclose(y, y_exp));
2761 if (inplace) {
2762 ASSERT_TRUE(torch::allclose(x, y_exp));
2763 } else {
2764 s.backward();
2765 }
2766 }
2767 }
2768
TEST_F(ModulesTest,ReLU6)2769 TEST_F(ModulesTest, ReLU6) {
2770 for (const auto inplace : {false, true}) {
2771 const auto size = 3;
2772 ReLU6 model(inplace);
2773 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2774 x.resize_({size, size, size});
2775 if (!inplace) {
2776 x.requires_grad_(true);
2777 }
2778 auto x_orig = x.clone();
2779 auto y = model(x);
2780 torch::Tensor s = y.sum();
2781
2782 ASSERT_EQ(s.ndimension(), 0);
2783 ASSERT_EQ(y.ndimension(), 3);
2784 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2785 auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig +
2786 (x_orig > 6) * 6;
2787 ASSERT_TRUE(torch::allclose(y, y_exp));
2788 if (inplace) {
2789 ASSERT_TRUE(torch::allclose(x, y_exp));
2790 } else {
2791 s.backward();
2792 }
2793 }
2794 }
2795
TEST_F(ModulesTest,RReLU)2796 TEST_F(ModulesTest, RReLU) {
2797 const auto size = 3;
2798 for (const auto lower : {0.01, 0.1, 0.2}) {
2799 for (const auto upper : {0.3, 0.4, 0.5}) {
2800 for (const auto inplace : {false, true}) {
2801 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2802 RReLU model{
2803 RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
2804 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2805 x.resize_({size, size, size});
2806 if (!inplace) {
2807 x.requires_grad_(true);
2808 }
2809 auto x_orig = x.clone();
2810 auto y = model(x);
2811 torch::Tensor s = y.sum();
2812
2813 ASSERT_EQ(s.ndimension(), 0);
2814 ASSERT_EQ(y.ndimension(), 3);
2815 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2816 auto z =
2817 ((x_orig >= 0) * (x_orig == y) +
2818 (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) *
2819 1.0;
2820 ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
2821 if (inplace) {
2822 ASSERT_TRUE(torch::allclose(x, y));
2823 } else {
2824 s.backward();
2825 }
2826 }
2827 }
2828 }
2829 }
2830 }
2831
TEST_F(ModulesTest,CELU)2832 TEST_F(ModulesTest, CELU) {
2833 const auto size = 3;
2834 for (const auto inplace : {false, true}) {
2835 for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
2836 CELU model{CELUOptions().alpha(alpha).inplace(inplace)};
2837 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2838 x.resize_({size, size, size});
2839 if (!inplace) {
2840 x.requires_grad_(true);
2841 }
2842 auto x_orig = x.clone();
2843 auto y = model(x);
2844 torch::Tensor s = y.sum();
2845
2846 ASSERT_EQ(s.ndimension(), 0);
2847 ASSERT_EQ(y.ndimension(), 3);
2848 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2849 auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2850 torch::min(torch::zeros_like(x_orig),
2851 alpha * (torch::exp(x_orig / alpha) - 1.0));
2852 ASSERT_TRUE(torch::allclose(y, y_exp));
2853 if (inplace) {
2854 ASSERT_TRUE(torch::allclose(x, y_exp));
2855 } else {
2856 s.backward();
2857 }
2858 }
2859 }
2860 }
2861
TEST_F(ModulesTest,GLU)2862 TEST_F(ModulesTest, GLU) {
2863 int64_t dim = 1;
2864 GLU model(dim);
2865 auto input = torch::randn({4, 2}, torch::requires_grad());
2866 auto output = model->forward(input);
2867 auto input_size = input.sizes()[dim] / 2;
2868 auto first_half = input.narrow(dim, 0, input_size);
2869 auto second_half = input.narrow(dim, input_size, input_size);
2870 auto expected = first_half * torch::sigmoid(second_half);
2871 auto s = output.sum();
2872 s.backward();
2873
2874 ASSERT_EQ(s.ndimension(), 0);
2875 ASSERT_TRUE(output.allclose(expected));
2876
2877 GLU model_default_options;
2878 ASSERT_TRUE(model_default_options->forward(input).allclose(expected));
2879 }
2880
TEST_F(ModulesTest,GELU)2881 TEST_F(ModulesTest, GELU) {
2882 GELU model(GELUOptions().approximate("none"));
2883 const auto x = torch::linspace(-3.0, 3.0, 100);
2884 const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
2885 const auto y = model(x);
2886 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2887 }
2888
TEST_F(ModulesTest,TanhGELU)2889 TEST_F(ModulesTest, TanhGELU) {
2890 GELU model(GELUOptions().approximate("tanh"));
2891 const auto x = torch::linspace(-3.0, 3.0, 100);
2892 const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
2893 const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
2894 const auto y = model(x);
2895 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2896 }
2897
2898 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModulesTest,Mish)2899 TEST_F(ModulesTest, Mish) {
2900 Mish model;
2901 auto x = torch::randn(100) * 10;
2902 auto y_exp = x * x.exp().log1p().tanh();
2903 auto y = model(x);
2904
2905 ASSERT_TRUE(torch::allclose(y, y_exp));
2906 }
2907
TEST_F(ModulesTest,Sigmoid)2908 TEST_F(ModulesTest, Sigmoid) {
2909 Sigmoid model;
2910 auto x = torch::randn(100) * 10;
2911 auto y_exp = 1 / (1 + torch::exp(-x));
2912 auto y = model(x);
2913
2914 ASSERT_TRUE(torch::allclose(y, y_exp));
2915 }
2916
TEST_F(ModulesTest,PixelShuffle)2917 TEST_F(ModulesTest, PixelShuffle) {
2918 PixelShuffle module(/*upscale_factor=*/2);
2919 auto x = torch::tensor(
2920 {{{{-17, 19}, {-1, 2}},
2921 {{7, 14}, {-3, 1}},
2922 {{0, -2}, {-12, 14}},
2923 {{-15, 0}, {-3, 9}}}},
2924 torch::kFloat);
2925 auto y_exp = torch::tensor(
2926 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2927 torch::kFloat);
2928 auto y = module(x);
2929
2930 ASSERT_EQ(y.ndimension(), 4);
2931 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
2932 ASSERT_TRUE(y.allclose(y_exp));
2933 }
2934
TEST_F(ModulesTest,PixelUnshuffle)2935 TEST_F(ModulesTest, PixelUnshuffle) {
2936 PixelUnshuffle module(/*downscale_factor=*/2);
2937 auto x = torch::tensor(
2938 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2939 torch::kFloat);
2940 auto y_exp = torch::tensor(
2941 {{{{-17, 19}, {-1, 2}},
2942 {{7, 14}, {-3, 1}},
2943 {{0, -2}, {-12, 14}},
2944 {{-15, 0}, {-3, 9}}}},
2945 torch::kFloat);
2946 auto y = module(x);
2947
2948 ASSERT_EQ(y.ndimension(), 4);
2949 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
2950 ASSERT_TRUE(y.allclose(y_exp));
2951 }
2952
TEST_F(ModulesTest,Softplus)2953 TEST_F(ModulesTest, Softplus) {
2954 const auto size = 3;
2955 for (const auto beta : {0.5, 1.0, 2.0}) {
2956 for (const auto threshold : {1.0, 3.0, 5.0}) {
2957 Softplus model{SoftplusOptions().beta(beta).threshold(threshold)};
2958 auto x = torch::linspace(-3.0, 3.0, 61);
2959 x.resize_({size, size, size});
2960 auto y_exp =
2961 (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta +
2962 (x > threshold) * x;
2963 auto y = model(x);
2964
2965 ASSERT_EQ(y.ndimension(), 3);
2966 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2967 ASSERT_TRUE(torch::allclose(y, y_exp));
2968 }
2969 }
2970 }
2971
TEST_F(ModulesTest,Softshrink)2972 TEST_F(ModulesTest, Softshrink) {
2973 const auto size = 3;
2974 for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2975 Softshrink model{/*lambda=*/lambda};
2976 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2977 x.resize_({size, size, size}).set_requires_grad(true);
2978 auto y = model(x);
2979 torch::Tensor s = y.sum();
2980
2981 s.backward();
2982 ASSERT_EQ(s.ndimension(), 0);
2983
2984 ASSERT_EQ(y.ndimension(), 3);
2985 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2986 auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
2987 ASSERT_TRUE(torch::allclose(y, y_exp));
2988 }
2989 }
2990
TEST_F(ModulesTest,Softsign)2991 TEST_F(ModulesTest, Softsign) {
2992 Softsign model;
2993 auto x = torch::randn(100) * 10;
2994 auto y_exp = x / (1 + x.abs());
2995 auto y = model(x);
2996
2997 ASSERT_TRUE(torch::allclose(y, y_exp));
2998 }
2999
TEST_F(ModulesTest,Tanh)3000 TEST_F(ModulesTest, Tanh) {
3001 Tanh model;
3002 auto x = torch::randn(100) * 10;
3003 auto y_exp = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp());
3004 auto y = model(x);
3005
3006 ASSERT_TRUE(torch::allclose(y, y_exp));
3007 }
3008
TEST_F(ModulesTest,Tanhshrink)3009 TEST_F(ModulesTest, Tanhshrink) {
3010 Tanhshrink model;
3011 auto x = torch::randn(100) * 10;
3012 auto y_exp = x - x.tanh();
3013 auto y = model(x);
3014
3015 ASSERT_TRUE(torch::allclose(y, y_exp));
3016 }
3017
TEST_F(ModulesTest,Threshold)3018 TEST_F(ModulesTest, Threshold) {
3019 const auto size = 3;
3020 for (const auto threshold : {0.5, 1.0, 2.0}) {
3021 for (const auto value : {0.5, 1.0, 2.0}) {
3022 for (const auto inplace : {false, true}) {
3023 Threshold model{ThresholdOptions(threshold, value).inplace(inplace)};
3024 auto x = torch::linspace(-3.0, 3.0, 61);
3025 x.resize_({size, size, size});
3026 auto x_orig = x.clone();
3027 auto y_exp =
3028 (x_orig <= threshold) * value + (x_orig > threshold) * x_orig;
3029 auto y = model(x);
3030
3031 ASSERT_EQ(y.ndimension(), 3);
3032 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
3033 ASSERT_TRUE(torch::allclose(y, y_exp));
3034 if (inplace) {
3035 ASSERT_TRUE(torch::allclose(x, y_exp));
3036 }
3037 }
3038 }
3039 }
3040 }
3041
TEST_F(ModulesTest,Upsampling1D)3042 TEST_F(ModulesTest, Upsampling1D) {
3043 {
3044 Upsample model(UpsampleOptions()
3045 .size(std::vector<int64_t>({4}))
3046 .mode(torch::kNearest));
3047 auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3048 auto output = model->forward(input);
3049 auto expected = torch::ones({1, 1, 4});
3050 auto s = output.sum();
3051 s.backward();
3052
3053 ASSERT_EQ(s.ndimension(), 0);
3054 ASSERT_TRUE(output.allclose(expected));
3055 }
3056 {
3057 for (const auto align_corners : {true, false}) {
3058 // test float scale factor up & down sampling
3059 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3060 Upsample model(UpsampleOptions()
3061 .scale_factor(std::vector<double>({scale_factor}))
3062 .mode(torch::kLinear)
3063 .align_corners(align_corners));
3064 auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3065 auto output = model->forward(input);
3066 auto expected_size =
3067 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3068 auto expected = torch::ones({1, 1, expected_size});
3069 auto s = output.sum();
3070 s.backward();
3071
3072 ASSERT_EQ(s.ndimension(), 0);
3073 ASSERT_TRUE(output.allclose(expected));
3074 }
3075 }
3076 }
3077 {
3078 // linear (1D) upsampling spatial invariance
3079 Upsample model(UpsampleOptions()
3080 .scale_factor(std::vector<double>({3}))
3081 .mode(torch::kLinear)
3082 .align_corners(false));
3083 auto input = torch::zeros({1, 1, 9});
3084 input.narrow(2, 0, 4).normal_();
3085 auto output = model->forward(input);
3086 auto expected = model->forward(input.narrow(2, 0, 5));
3087
3088 ASSERT_TRUE(torch::allclose(output.narrow(2, 0, 15), expected));
3089 }
3090 }
3091
TEST_F(ModulesTest,Upsampling2D)3092 TEST_F(ModulesTest, Upsampling2D) {
3093 {
3094 Upsample model(UpsampleOptions()
3095 .size(std::vector<int64_t>({4, 4}))
3096 .mode(torch::kNearest));
3097 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3098 auto output = model->forward(input);
3099 auto expected = torch::ones({1, 1, 4, 4});
3100 auto s = output.sum();
3101 s.backward();
3102
3103 ASSERT_EQ(s.ndimension(), 0);
3104 ASSERT_TRUE(output.allclose(expected));
3105 }
3106 {
3107 for (const auto align_corners : {true, false}) {
3108 // test float scale factor up & down sampling
3109 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3110 Upsample model(
3111 UpsampleOptions()
3112 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3113 .mode(torch::kBilinear)
3114 .align_corners(align_corners));
3115 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3116 auto output = model->forward(input);
3117 auto expected_size =
3118 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3119 auto expected = torch::ones({1, 1, expected_size, expected_size});
3120 auto s = output.sum();
3121 s.backward();
3122
3123 ASSERT_EQ(s.ndimension(), 0);
3124 ASSERT_TRUE(output.allclose(expected));
3125 }
3126 }
3127 }
3128 {
3129 for (const auto align_corners : {true, false}) {
3130 // test float scale factor up & down sampling
3131 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3132 Upsample model(
3133 UpsampleOptions()
3134 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3135 .mode(torch::kBicubic)
3136 .align_corners(align_corners));
3137 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3138 auto output = model->forward(input);
3139 auto expected_size =
3140 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3141 auto expected = torch::ones({1, 1, expected_size, expected_size});
3142 auto s = output.sum();
3143 s.backward();
3144
3145 ASSERT_EQ(s.ndimension(), 0);
3146 ASSERT_TRUE(output.allclose(expected));
3147 }
3148 }
3149 }
3150 }
3151
TEST_F(ModulesTest,Upsampling3D)3152 TEST_F(ModulesTest, Upsampling3D) {
3153 {
3154 Upsample model(UpsampleOptions()
3155 .size(std::vector<int64_t>({4, 4, 4}))
3156 .mode(torch::kNearest));
3157 auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3158 auto output = model->forward(input);
3159 auto expected = torch::ones({1, 1, 4, 4, 4});
3160 auto s = output.sum();
3161 s.backward();
3162
3163 ASSERT_EQ(s.ndimension(), 0);
3164 ASSERT_TRUE(output.allclose(expected));
3165 }
3166 {
3167 for (const auto align_corners : {true, false}) {
3168 // test float scale factor up & down sampling
3169 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3170 Upsample model(UpsampleOptions()
3171 .scale_factor(std::vector<double>(
3172 {scale_factor, scale_factor, scale_factor}))
3173 .mode(torch::kTrilinear)
3174 .align_corners(align_corners));
3175 auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3176 auto output = model->forward(input);
3177 auto expected_size =
3178 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3179 auto expected =
3180 torch::ones({1, 1, expected_size, expected_size, expected_size});
3181 auto s = output.sum();
3182 s.backward();
3183
3184 ASSERT_EQ(s.ndimension(), 0);
3185 ASSERT_TRUE(output.allclose(expected));
3186 }
3187 }
3188 }
3189 }
3190
TEST_F(ModulesTest,CTCLoss)3191 TEST_F(ModulesTest, CTCLoss) {
3192 CTCLoss loss{CTCLossOptions().reduction(torch::kNone)};
3193 const auto target_lengths = torch::tensor({0, 0, 0});
3194 const auto input_lengths = torch::tensor({50, 50, 50});
3195 const auto targets =
3196 torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
3197 const auto log_probs =
3198 torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
3199 const auto output =
3200 loss->forward(log_probs, targets, input_lengths, target_lengths);
3201 ASSERT_TRUE(output.ge(0).all().item<bool>());
3202 ASSERT_TRUE(torch::allclose(
3203 -log_probs.sum(0).slice(1, 0, 1).view_as(output), output));
3204 }
3205
TEST_F(ModulesTest,PoissonNLLLoss)3206 TEST_F(ModulesTest, PoissonNLLLoss) {
3207 const auto input = torch::tensor({0.5, 1.5, 2.5});
3208 const auto target = torch::tensor({1., 2., 3.});
3209 const auto component_wise_loss = torch::exp(input) - target * input;
3210 {
3211 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kNone)};
3212 ASSERT_TRUE(
3213 torch::allclose(component_wise_loss, loss->forward(input, target)));
3214 }
3215 {
3216 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kSum)};
3217 ASSERT_TRUE(torch::allclose(
3218 torch::sum(component_wise_loss), loss->forward(input, target)));
3219 }
3220 {
3221 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kMean)};
3222 ASSERT_TRUE(torch::allclose(
3223 torch::mean(component_wise_loss), loss->forward(input, target)));
3224 }
3225 }
3226
TEST_F(ModulesTest,MarginRankingLoss)3227 TEST_F(ModulesTest, MarginRankingLoss) {
3228 {
3229 MarginRankingLoss loss;
3230 const auto input1 = torch::randn(15) * 10;
3231 const auto input2 = torch::randn(15) * 10;
3232 const auto target = torch::randn(15).sign();
3233 ASSERT_TRUE(torch::allclose(
3234 loss->forward(input1, input2, target),
3235 (-target * (input1 - input2)).clamp(0).mean()));
3236 }
3237 {
3238 MarginRankingLoss loss{
3239 MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)};
3240 const auto input1 = torch::randn(15) * 10;
3241 const auto input2 = torch::randn(15) * 10;
3242 const auto target = torch::randn(15).sign();
3243 const auto margin = 0.5;
3244 ASSERT_TRUE(torch::allclose(
3245 loss->forward(input1, input2, target),
3246 (-target * (input1 - input2) + margin).clamp(0).sum()));
3247 }
3248 {
3249 MarginRankingLoss loss{
3250 MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)};
3251 const auto input1 = torch::randn(15) * 10;
3252 const auto input2 = torch::randn(15) * 10;
3253 const auto target = torch::randn(15).sign();
3254 const auto margin = 0.5;
3255 ASSERT_TRUE(torch::allclose(
3256 loss->forward(input1, input2, target),
3257 (-target * (input1 - input2) + margin).clamp(0).mean()));
3258 }
3259 }
3260
TEST_F(ModulesTest,BCEWithLogitsLoss)3261 TEST_F(ModulesTest, BCEWithLogitsLoss) {
3262 { // test BCE with logits raises if target and input are different size
3263 {
3264 const auto target = torch::rand(5);
3265 const auto input = torch::rand({5, 1});
3266 ASSERT_THROWS_WITH(
3267 BCEWithLogitsLoss()(input, target), "must be the same as input size");
3268 }
3269 {
3270 const auto target = torch::rand({5, 1});
3271 const auto input = torch::rand(5);
3272 ASSERT_THROWS_WITH(
3273 BCEWithLogitsLoss()(input, target), "must be the same as input size");
3274 }
3275 }
3276 { // test BCE with logits gives same result as sigmoid and bce loss
3277 auto sigmoid = Sigmoid();
3278
3279 auto target = torch::rand({64, 4});
3280 auto output = torch::rand({64, 4}) - 0.5;
3281
3282 ASSERT_TRUE(torch::allclose(
3283 BCEWithLogitsLoss()(output, target),
3284 BCELoss()(sigmoid(output), target)));
3285
3286 auto weight = torch::rand(4);
3287 ASSERT_TRUE(torch::allclose(
3288 BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3289 output, target),
3290 BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3291
3292 target = torch::zeros({4, 1}, torch::kFloat);
3293 output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3294
3295 ASSERT_TRUE(torch::allclose(
3296 BCEWithLogitsLoss()(output, target),
3297 BCELoss()(sigmoid(output), target)));
3298
3299 ASSERT_TRUE(torch::allclose(
3300 BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))(
3301 output, target),
3302 BCELoss(BCELossOptions().reduction(torch::kNone))(
3303 sigmoid(output), target)));
3304
3305 weight = torch::rand({1}, torch::kFloat);
3306 ASSERT_TRUE(torch::allclose(
3307 BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3308 output, target),
3309 BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3310 }
3311 { // test BCE with logits has correct grad at zero
3312 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3313 const auto target = torch::zeros({3, 1});
3314 BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))(
3315 output, target)
3316 .backward();
3317 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3318 ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3319 }
3320 { // test BCE with logits broadcasts weights
3321 const auto target = torch::rand({16, 4});
3322 const auto output = torch::rand({16, 4}) - 0.5;
3323
3324 auto weight = torch::rand(4);
3325 auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3326 output, target);
3327
3328 weight = weight.expand({16, 4}).contiguous();
3329 auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3330 output, target);
3331
3332 ASSERT_TRUE(torch::allclose(out1, out2));
3333
3334 weight = torch::rand({16, 1});
3335 out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3336 output, target);
3337
3338 weight = weight.expand({16, 4}).contiguous();
3339 out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3340 output, target);
3341
3342 ASSERT_TRUE(torch::allclose(out1, out2));
3343 }
3344 { // test BCE with logits ones in pos weights are the same as none
3345 const auto target = torch::rand({64, 4});
3346 const auto output = torch::rand({64, 4}) - 0.5;
3347 const auto pos_weight = torch::ones({64, 4});
3348
3349 ASSERT_TRUE(torch::allclose(
3350 BCEWithLogitsLoss()(output, target),
3351 BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))(
3352 output, target)));
3353 }
3354 { // test BCE with logits broadcasts pos weights
3355 const auto target = torch::rand({64, 4});
3356 const auto output = torch::rand({64, 4}) - 0.5;
3357 const auto pos_weight = torch::rand(4);
3358 const auto out1 = BCEWithLogitsLoss(
3359 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3360
3361 const auto pos_weight1 = pos_weight.expand({1, 4});
3362 const auto out2 = BCEWithLogitsLoss(
3363 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3364
3365 const auto pos_weight2 = pos_weight.expand({64, 4});
3366 const auto out3 = BCEWithLogitsLoss(
3367 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3368
3369 ASSERT_TRUE(torch::allclose(out1, out2));
3370 ASSERT_TRUE(torch::allclose(out1, out3));
3371 }
3372 { // test BCE with logits with pos weight has correct grad at zero
3373 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3374 const auto target = torch::zeros({3, 1});
3375 const auto pos_weight = torch::ones({3, 1});
3376 BCEWithLogitsLoss(BCEWithLogitsLossOptions()
3377 .pos_weight(pos_weight)
3378 .reduction(torch::kSum))(output, target)
3379 .backward();
3380 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3381 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3382 const auto grad = output.grad();
3383 ASSERT_TRUE(torch::allclose(grad, expected_grad));
3384 }
3385 { // test BCE with logits stability
3386 const auto output = torch::tensor({0., -120.});
3387 const auto target = torch::tensor({0., 1.});
3388 const auto pos_weight = torch::tensor({1., 1.});
3389
3390 const auto out1 = BCEWithLogitsLoss()(output, target);
3391 ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3392
3393 const auto out2 = BCEWithLogitsLoss(
3394 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3395 ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3396 }
3397 }
3398
3399 namespace detail {
3400
3401 namespace F = torch::nn::functional;
3402
_batchmatmul(const torch::Tensor & a,const torch::Tensor & b)3403 torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) {
3404 TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0));
3405 TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1));
3406 auto retval = torch::zeros(
3407 {a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32);
3408 for (const auto i : c10::irange(a.size(0))) {
3409 for (const auto j : c10::irange(a.size(1))) {
3410 retval[i][j] = torch::matmul(a[i][j], b[i][j]);
3411 }
3412 }
3413 return retval;
3414 }
3415
_softmax(const torch::Tensor & x)3416 torch::Tensor _softmax(const torch::Tensor& x) {
3417 auto output = torch::zeros(x.sizes());
3418 for (const auto i : c10::irange(x.size(0))) {
3419 for (const auto j : c10::irange(x.size(1))) {
3420 for (const auto k : c10::irange(x.size(2))) {
3421 const auto& x_curr = x[i][j][k];
3422 const auto e_x = torch::exp(x_curr - torch::max(x_curr));
3423 output[i][j][k] = e_x / torch::sum(e_x);
3424 }
3425 }
3426 }
3427 return output;
3428 }
3429
_scaled_dot_attn_ref(const torch::Tensor & Q,const torch::Tensor & K,const torch::Tensor & V,at::IntArrayRef dims,const torch::Tensor & unseen_mask={},const torch::Tensor & key_padding_mask={},bool average_attn_weights=true)3430 std::tuple<torch::Tensor, torch::Tensor> _scaled_dot_attn_ref(
3431 const torch::Tensor& Q,
3432 const torch::Tensor& K,
3433 const torch::Tensor& V,
3434 at::IntArrayRef dims,
3435 const torch::Tensor& unseen_mask = {},
3436 const torch::Tensor& key_padding_mask = {},
3437 bool average_attn_weights = true) {
3438 auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3]));
3439 const auto b1 = QKT.size(0);
3440 const auto b2 = QKT.size(1);
3441 const auto s1 = QKT.size(2);
3442 const auto s2 = QKT.size(3);
3443 if (unseen_mask.defined() || key_padding_mask.defined()) {
3444 for (const auto i : c10::irange(b1)) {
3445 for (const auto j : c10::irange(b2)) {
3446 for (const auto m : c10::irange(s1)) {
3447 for (const auto n : c10::irange(s2)) {
3448 if (unseen_mask.defined() &&
3449 unseen_mask[m][n].item<double>() == 0) {
3450 QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3451 }
3452 if (key_padding_mask.defined() &&
3453 key_padding_mask[i][n].item<double>() != 0) {
3454 QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3455 }
3456 }
3457 }
3458 }
3459 }
3460 }
3461 auto reference = _softmax(QKT);
3462 auto ref_attn_weight = reference;
3463 if (average_attn_weights) {
3464 // NOLINTNEXTLINE(bugprone-argument-comment)
3465 ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2;
3466 }
3467 reference = _batchmatmul(reference, V);
3468 return std::tie(reference, ref_attn_weight);
3469 }
3470
_split_heads_ref(const torch::Tensor & X,at::IntArrayRef dims,int nheads,int d_head)3471 torch::Tensor _split_heads_ref(
3472 const torch::Tensor& X,
3473 at::IntArrayRef dims,
3474 int nheads,
3475 int d_head) {
3476 auto X_split = X.reshape({dims[0], dims[1], nheads, d_head});
3477 auto X_split_transposed = X_split.permute({0, 2, 1, 3});
3478 return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head});
3479 }
3480
_combine_heads_ref(const torch::Tensor & X,at::IntArrayRef dims,int nheads,int d_head)3481 torch::Tensor _combine_heads_ref(
3482 const torch::Tensor& X,
3483 at::IntArrayRef dims,
3484 int nheads,
3485 int d_head) {
3486 auto X_transposed = X.permute({0, 2, 1, 3});
3487 auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head});
3488 return reference;
3489 }
3490
_fc(torch::Tensor X,torch::Tensor X_weight,torch::Tensor X_bias)3491 torch::Tensor _fc(
3492 torch::Tensor X,
3493 torch::Tensor X_weight,
3494 torch::Tensor X_bias) {
3495 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3496 auto X_fc_b = X_bias;
3497 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3498 auto X_fc_w = X_weight;
3499 return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b;
3500 }
3501
_multihead_attn_test_helper(bool add_key_padding_mask=false,bool add_bias_kv=false,bool add_zero_attn=false,bool saved_kv=false,bool same_embed_dim=false,bool average_attn_weights=true)3502 void _multihead_attn_test_helper(
3503 bool add_key_padding_mask = false,
3504 bool add_bias_kv = false,
3505 bool add_zero_attn = false,
3506 bool saved_kv = false,
3507 bool same_embed_dim = false,
3508 bool average_attn_weights = true) {
3509 std::random_device device;
3510 std::mt19937 generator(device());
3511 std::uniform_int_distribution<int> d_2_10(2, 10);
3512 std::uniform_int_distribution<int> d_3_10(3, 10);
3513 bool registration_checked = false;
3514 for (const auto i : c10::irange(100)) {
3515 (void)i; // Suppress unused variable warning
3516 const auto batch_sz = d_2_10(generator);
3517 const auto seq_len = d_2_10(generator);
3518 const auto d_head = d_3_10(generator);
3519 const auto nheads = d_3_10(generator);
3520 const auto d_model = d_head * nheads;
3521 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3522 int kv_dim;
3523 if (same_embed_dim) {
3524 kv_dim = d_model;
3525 } else {
3526 std::uniform_int_distribution<int> d(5, 20);
3527 kv_dim = d(generator);
3528 while (kv_dim == d_model) {
3529 kv_dim = d(generator);
3530 }
3531 }
3532 std::vector<int64_t> dims{batch_sz, seq_len, kv_dim};
3533 torch::Tensor saved_k;
3534 torch::Tensor saved_k_tensor;
3535 torch::Tensor saved_v;
3536 torch::Tensor saved_v_tensor;
3537 if (saved_kv) {
3538 saved_k = torch::rand({batch_sz * nheads, seq_len, d_head});
3539 saved_k_tensor = saved_k;
3540 saved_v = torch::rand({batch_sz * nheads, seq_len, d_head});
3541 saved_v_tensor = saved_v;
3542 }
3543 torch::Tensor key_padding_mask;
3544 torch::Tensor key_padding_mask_tensor;
3545 if (add_key_padding_mask) {
3546 const auto seq_mask = torch::randint(0, 2, {1, seq_len});
3547 key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1;
3548 key_padding_mask_tensor = key_padding_mask;
3549 }
3550 const auto decoder_state = torch::rand({batch_sz, d_model});
3551 const torch::Tensor K = torch::rand(dims);
3552 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3553 const torch::Tensor V = K;
3554 const torch::Tensor Q =
3555 decoder_state.clone().resize_({batch_sz, 1, d_model});
3556 auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat);
3557 const torch::Tensor attn_mask_tensor = attn_mask.clone();
3558 attn_mask_tensor.masked_fill_(
3559 attn_mask_tensor == 0, -std::numeric_limits<double>::infinity());
3560 attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0));
3561
3562 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3563 const torch::Tensor decoder_state_tensor = decoder_state;
3564 const torch::Tensor source_hid_tensor = K.transpose(0, 1);
3565
3566 const auto options = MultiheadAttentionOptions(d_model, nheads)
3567 .add_bias_kv(add_bias_kv)
3568 .add_zero_attn(add_zero_attn)
3569 .kdim(kv_dim)
3570 .vdim(kv_dim);
3571 const auto multihead_attn_module = MultiheadAttention(options);
3572
3573 if (!registration_checked) {
3574 // make sure parameters are all registered correctly
3575 auto named_parameters = multihead_attn_module->named_parameters();
3576 if (same_embed_dim) {
3577 ASSERT_TRUE(named_parameters.contains("in_proj_weight"));
3578 } else {
3579 ASSERT_TRUE(named_parameters.contains("q_proj_weight"));
3580 ASSERT_TRUE(named_parameters.contains("k_proj_weight"));
3581 ASSERT_TRUE(named_parameters.contains("v_proj_weight"));
3582 }
3583 if (add_bias_kv) {
3584 ASSERT_TRUE(named_parameters.contains("bias_k"));
3585 ASSERT_TRUE(named_parameters.contains("bias_v"));
3586 }
3587 // make sure sub modules are all registered correctly
3588 auto submodules = multihead_attn_module->named_children();
3589 ASSERT_TRUE(submodules.contains("out_proj"));
3590 registration_checked = true;
3591 }
3592
3593 torch::Tensor bias_k;
3594 torch::Tensor bias_v;
3595 if (add_bias_kv) {
3596 bias_k = multihead_attn_module->bias_k.detach();
3597 bias_v = multihead_attn_module->bias_v.detach();
3598 } else {
3599 bias_k.reset();
3600 bias_v.reset();
3601 }
3602
3603 torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1);
3604 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3605 torch::Tensor _V = source_hid_tensor;
3606 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3607 torch::Tensor _K = source_hid_tensor;
3608
3609 torch::Tensor result;
3610 torch::Tensor result_weight;
3611 if (multihead_attn_module->_qkv_same_embed_dim) {
3612 std::tie(result, result_weight) = F::multi_head_attention_forward(
3613 _Q,
3614 _K,
3615 _V,
3616 F::MultiheadAttentionForwardFuncOptions(
3617 /*embed_dim_to_check=*/d_model,
3618 /*num_heads=*/nheads,
3619 /*in_proj_weight=*/multihead_attn_module->in_proj_weight,
3620 /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3621 /*bias_k=*/multihead_attn_module->bias_k,
3622 /*bias_v=*/multihead_attn_module->bias_v,
3623 /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3624 /*dropout_p=*/multihead_attn_module->options.dropout(),
3625 /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3626 /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3627 .training(multihead_attn_module->is_training())
3628 .key_padding_mask(key_padding_mask_tensor)
3629 .need_weights(true)
3630 .attn_mask(attn_mask_tensor)
3631 .static_k(saved_k_tensor)
3632 .static_v(saved_v_tensor)
3633 .average_attn_weights(average_attn_weights));
3634 } else {
3635 std::tie(result, result_weight) = F::multi_head_attention_forward(
3636 _Q,
3637 _K,
3638 _V,
3639 F::MultiheadAttentionForwardFuncOptions(
3640 /*embed_dim_to_check=*/d_model,
3641 /*num_heads=*/nheads,
3642 /*in_proj_weight=*/{},
3643 /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3644 /*bias_k=*/multihead_attn_module->bias_k,
3645 /*bias_v=*/multihead_attn_module->bias_v,
3646 /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3647 /*dropout_p=*/multihead_attn_module->options.dropout(),
3648 /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3649 /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3650 .training(multihead_attn_module->is_training())
3651 .key_padding_mask(key_padding_mask_tensor)
3652 .need_weights(true)
3653 .attn_mask(attn_mask_tensor)
3654 .use_separate_proj_weight(true)
3655 .q_proj_weight(multihead_attn_module->q_proj_weight)
3656 .k_proj_weight(multihead_attn_module->k_proj_weight)
3657 .v_proj_weight(multihead_attn_module->v_proj_weight)
3658 .static_k(saved_k_tensor)
3659 .static_v(saved_v_tensor)
3660 .average_attn_weights(average_attn_weights));
3661 }
3662 result = result.squeeze(0).detach();
3663 torch::Tensor q_proj_weight;
3664 torch::Tensor k_proj_weight;
3665 torch::Tensor v_proj_weight;
3666 if (multihead_attn_module->_qkv_same_embed_dim) {
3667 q_proj_weight =
3668 multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model);
3669 k_proj_weight = multihead_attn_module->in_proj_weight.slice(
3670 /*dim=*/0, d_model, (d_model * 2));
3671 v_proj_weight =
3672 multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2));
3673 } else {
3674 q_proj_weight = multihead_attn_module->q_proj_weight;
3675 k_proj_weight = multihead_attn_module->k_proj_weight;
3676 v_proj_weight = multihead_attn_module->v_proj_weight;
3677 }
3678 auto Q_fc =
3679 _fc(Q,
3680 q_proj_weight,
3681 multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model));
3682 auto K_fc =
3683 _fc(K,
3684 k_proj_weight,
3685 multihead_attn_module->in_proj_bias.slice(
3686 /*dim=*/0, d_model, (d_model * 2)));
3687 auto V_fc = _fc(
3688 V,
3689 v_proj_weight,
3690 multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2)));
3691
3692 if (add_bias_kv) {
3693 K_fc = torch::cat(
3694 {K_fc,
3695 bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)},
3696 /*dim=*/1);
3697 V_fc = torch::cat(
3698 {V_fc,
3699 bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)},
3700 /*dim=*/1);
3701 if (attn_mask.defined()) {
3702 attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3703 }
3704 if (key_padding_mask.defined()) {
3705 key_padding_mask = torch::cat(
3706 {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3707 /*dim=*/1);
3708 }
3709 dims[1] += 1;
3710 }
3711 const auto Q_split =
3712 _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head);
3713 torch::Tensor K_split;
3714 if (saved_k.defined()) {
3715 K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head});
3716 } else {
3717 K_split = _split_heads_ref(K_fc, dims, nheads, d_head);
3718 }
3719 torch::Tensor V_split;
3720 if (saved_v.defined()) {
3721 V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head});
3722 } else {
3723 V_split = _split_heads_ref(V_fc, dims, nheads, d_head);
3724 }
3725 if (add_zero_attn) {
3726 dims[1] += 1;
3727 K_split = torch::cat(
3728 {K_split,
3729 torch::zeros(
3730 {K_split.size(0), K_split.size(1), 1, K_split.size(3)})},
3731 /*dim=*/2);
3732 V_split = torch::cat(
3733 {V_split,
3734 torch::zeros(
3735 {V_split.size(0), V_split.size(1), 1, V_split.size(3)})},
3736 /*dim=*/2);
3737 if (attn_mask.defined()) {
3738 attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3739 }
3740 if (key_padding_mask.defined()) {
3741 key_padding_mask = torch::cat(
3742 {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3743 /*dim=*/1);
3744 }
3745 }
3746 auto [attn_heads, ref_attn_weight] = _scaled_dot_attn_ref(
3747 Q_split,
3748 K_split,
3749 V_split,
3750 Q_split.sizes(),
3751 attn_mask,
3752 key_padding_mask,
3753 average_attn_weights);
3754 const auto combined_attn_heads =
3755 _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head);
3756 auto reference =
3757 _fc(combined_attn_heads,
3758 multihead_attn_module->out_proj->weight,
3759 multihead_attn_module->out_proj->bias);
3760 // NOLINTNEXTLINE(bugprone-argument-comment)
3761 reference = torch::squeeze(reference, /*axis=*/1);
3762
3763 // result = reference
3764 ASSERT_EQ(result.sizes(), std::vector<int64_t>({batch_sz, d_model}));
3765 ASSERT_TRUE(
3766 torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true));
3767
3768 // result_weight = ref_attn_weight
3769 result_weight = result_weight.detach();
3770 ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes());
3771 ASSERT_TRUE(torch::allclose(
3772 result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true));
3773 }
3774 }
3775 } // namespace detail
3776
TEST_F(ModulesTest,MultiheadAttention)3777 TEST_F(ModulesTest, MultiheadAttention) {
3778 using namespace ::detail;
3779
3780 for (auto average_attn_weights : {false, true}) {
3781 // test_multihead_attn_add_zero_attn
3782 _multihead_attn_test_helper(
3783 /*add_key_padding_mask=*/false,
3784 /*add_bias_kv=*/false,
3785 /*add_zero_attn=*/true,
3786 /*saved_kv=*/false,
3787 /*same_embed_dim=*/false,
3788 /*average_attn_weights=*/average_attn_weights);
3789
3790 // test_multihead_attn_add_bias_kv
3791 _multihead_attn_test_helper(
3792 /*add_key_padding_mask=*/false,
3793 /*add_bias_kv=*/true,
3794 /*add_zero_attn=*/false,
3795 /*saved_kv=*/false,
3796 /*same_embed_dim=*/false,
3797 /*average_attn_weights=*/average_attn_weights);
3798
3799 // test_multihead_attn_no_masking():
3800 _multihead_attn_test_helper();
3801
3802 // test_multihead_attn_key_padding_mask
3803 _multihead_attn_test_helper(
3804 /*add_key_padding_mask=*/true,
3805 /*add_bias_kv=*/false,
3806 /*add_zero_attn=*/false,
3807 /*saved_kv=*/false,
3808 /*same_embed_dim=*/false,
3809 /*average_attn_weights=*/average_attn_weights);
3810
3811 // test_multihead_attn_saved_kv
3812 _multihead_attn_test_helper(
3813 /*add_key_padding_mask=*/false,
3814 /*add_bias_kv=*/false,
3815 /*add_zero_attn=*/false,
3816 /*saved_kv=*/true,
3817 /*same_embed_dim=*/false,
3818 /*average_attn_weights=*/average_attn_weights);
3819
3820 // test_multihead_attn_add_bias_kv_zero_attn
3821 _multihead_attn_test_helper(
3822 /*add_key_padding_mask=*/true,
3823 /*add_bias_kv=*/true,
3824 /*add_zero_attn=*/true,
3825 /*saved_kv=*/false,
3826 /*same_embed_dim=*/false,
3827 /*average_attn_weights=*/average_attn_weights);
3828
3829 // test_multihead_attn_all_arguments1
3830 _multihead_attn_test_helper(
3831 /*add_key_padding_mask=*/true,
3832 /*add_bias_kv=*/false,
3833 /*add_zero_attn=*/true,
3834 /*saved_kv=*/true,
3835 /*same_embed_dim=*/false,
3836 /*average_attn_weights=*/average_attn_weights);
3837
3838 ASSERT_THROWS_WITH(
3839 // test_multihead_attn_all_arguments2
3840 _multihead_attn_test_helper(
3841 /*add_key_padding_mask=*/true,
3842 /*add_bias_kv=*/true,
3843 /*add_zero_attn=*/true,
3844 /*saved_kv=*/true,
3845 /*same_embed_dim=*/false,
3846 /*average_attn_weights=*/average_attn_weights),
3847 "bias cannot be added to static key");
3848
3849 // test_multihead_attn_all_arguments3
3850 _multihead_attn_test_helper(
3851 /*add_key_padding_mask=*/true,
3852 /*add_bias_kv=*/false,
3853 /*add_zero_attn=*/true,
3854 /*saved_kv=*/true,
3855 /*same_embed_dim=*/true,
3856 /*average_attn_weights=*/average_attn_weights);
3857 }
3858 }
3859
TEST_F(ModulesTest,PrettyPrintIdentity)3860 TEST_F(ModulesTest, PrettyPrintIdentity) {
3861 ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
3862 }
3863
TEST_F(ModulesTest,PrettyPrintFlatten)3864 TEST_F(ModulesTest, PrettyPrintFlatten) {
3865 ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)");
3866 ASSERT_EQ(
3867 c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
3868 "torch::nn::Flatten(start_dim=2, end_dim=4)");
3869 }
3870
TEST_F(ModulesTest,PrettyPrintUnflatten)3871 TEST_F(ModulesTest, PrettyPrintUnflatten) {
3872 ASSERT_EQ(
3873 c10::str(Unflatten(UnflattenOptions(0, {2, 2}))),
3874 "torch::nn::Unflatten(dim=0, unflattened_size={2, 2})");
3875 ASSERT_EQ(
3876 c10::str(Unflatten(UnflattenOptions(
3877 "B",
3878 {std::pair<std::string, int64_t>{"B1", 2},
3879 std::pair<std::string, int64_t>{"B2", 2}}))),
3880 "torch::nn::Unflatten(dim=\"B\", unflattened_size={{\"B1\", 2}, {\"B2\", 2}})");
3881 }
3882
TEST_F(ModulesTest,ReflectionPad1d)3883 TEST_F(ModulesTest, ReflectionPad1d) {
3884 {
3885 ReflectionPad1d m(ReflectionPad1dOptions(2));
3886 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3887 auto output = m(input);
3888 auto expected = torch::tensor(
3889 {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}},
3890 torch::kFloat);
3891 ASSERT_TRUE(output.allclose(expected));
3892 }
3893 {
3894 ReflectionPad1d m(ReflectionPad1dOptions({3, 1}));
3895 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3896 auto output = m(input);
3897 auto expected = torch::tensor(
3898 {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}},
3899 torch::kFloat);
3900 ASSERT_TRUE(output.allclose(expected));
3901 }
3902 }
3903
TEST_F(ModulesTest,ReflectionPad2d)3904 TEST_F(ModulesTest, ReflectionPad2d) {
3905 {
3906 ReflectionPad2d m(ReflectionPad2dOptions(2));
3907 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3908 auto output = m(input);
3909 auto expected = torch::tensor(
3910 {{{{8., 7., 6., 7., 8., 7., 6.},
3911 {5., 4., 3., 4., 5., 4., 3.},
3912 {2., 1., 0., 1., 2., 1., 0.},
3913 {5., 4., 3., 4., 5., 4., 3.},
3914 {8., 7., 6., 7., 8., 7., 6.},
3915 {5., 4., 3., 4., 5., 4., 3.},
3916 {2., 1., 0., 1., 2., 1., 0.}}}},
3917 torch::kFloat);
3918 ASSERT_TRUE(output.allclose(expected));
3919 }
3920 {
3921 ReflectionPad2d m(ReflectionPad2dOptions({1, 1, 2, 0}));
3922 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3923 auto output = m(input);
3924 auto expected = torch::tensor(
3925 {{{{7., 6., 7., 8., 7.},
3926 {4., 3., 4., 5., 4.},
3927 {1., 0., 1., 2., 1.},
3928 {4., 3., 4., 5., 4.},
3929 {7., 6., 7., 8., 7.}}}},
3930 torch::kFloat);
3931 ASSERT_TRUE(output.allclose(expected));
3932 }
3933 }
3934
TEST_F(ModulesTest,ReflectionPad3d)3935 TEST_F(ModulesTest, ReflectionPad3d) {
3936 {
3937 ReflectionPad3d m(ReflectionPad3dOptions(1));
3938 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
3939 auto output = m(input);
3940 auto expected = torch::tensor(
3941 {{{{{7., 6., 7., 6.},
3942 {5., 4., 5., 4.},
3943 {7., 6., 7., 6.},
3944 {5., 4., 5., 4.}},
3945 {{3., 2., 3., 2.},
3946 {1., 0., 1., 0.},
3947 {3., 2., 3., 2.},
3948 {1., 0., 1., 0.}},
3949 {{7., 6., 7., 6.},
3950 {5., 4., 5., 4.},
3951 {7., 6., 7., 6.},
3952 {5., 4., 5., 4.}},
3953 {{3., 2., 3., 2.},
3954 {1., 0., 1., 0.},
3955 {3., 2., 3., 2.},
3956 {1., 0., 1., 0.}}}}},
3957 torch::kFloat);
3958 ASSERT_TRUE(output.allclose(expected));
3959 }
3960 {
3961 ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2}));
3962 auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2});
3963 auto output = m(input);
3964 auto expected = torch::tensor(
3965 {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3966 {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}},
3967 {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3968 {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3969 {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}},
3970 {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3971 {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}},
3972 torch::kFloat);
3973 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 3, 3}));
3974 ASSERT_TRUE(output.allclose(expected));
3975 }
3976 }
TEST_F(ModulesTest,ReplicationPad1d)3977 TEST_F(ModulesTest, ReplicationPad1d) {
3978 {
3979 ReplicationPad1d m(ReplicationPad1dOptions(2));
3980 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3981 auto output = m(input);
3982 auto expected = torch::tensor(
3983 {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}},
3984 torch::kFloat);
3985 ASSERT_TRUE(output.allclose(expected));
3986 }
3987 {
3988 ReplicationPad1d m(ReplicationPad1dOptions({3, 1}));
3989 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3990 auto output = m(input);
3991 auto expected = torch::tensor(
3992 {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}},
3993 torch::kFloat);
3994 ASSERT_TRUE(output.allclose(expected));
3995 }
3996 }
3997
TEST_F(ModulesTest,ReplicationPad2d)3998 TEST_F(ModulesTest, ReplicationPad2d) {
3999 {
4000 ReplicationPad2d m(ReplicationPad2dOptions(2));
4001 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4002 auto output = m(input);
4003 auto expected = torch::tensor(
4004 {{{{0., 0., 0., 1., 2., 2., 2.},
4005 {0., 0., 0., 1., 2., 2., 2.},
4006 {0., 0., 0., 1., 2., 2., 2.},
4007 {3., 3., 3., 4., 5., 5., 5.},
4008 {6., 6., 6., 7., 8., 8., 8.},
4009 {6., 6., 6., 7., 8., 8., 8.},
4010 {6., 6., 6., 7., 8., 8., 8.}}}},
4011 torch::kFloat);
4012 ASSERT_TRUE(output.allclose(expected));
4013 }
4014 {
4015 ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0}));
4016 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4017 auto output = m(input);
4018 auto expected = torch::tensor(
4019 {{{{0., 0., 1., 2., 2.},
4020 {0., 0., 1., 2., 2.},
4021 {0., 0., 1., 2., 2.},
4022 {3., 3., 4., 5., 5.},
4023 {6., 6., 7., 8., 8.}}}},
4024 torch::kFloat);
4025 ASSERT_TRUE(output.allclose(expected));
4026 }
4027 }
4028
TEST_F(ModulesTest,ReplicationPad3d)4029 TEST_F(ModulesTest, ReplicationPad3d) {
4030 {
4031 ReplicationPad3d m(ReplicationPad3dOptions(1));
4032 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4033 auto output = m(input);
4034 auto expected = torch::tensor(
4035 {{{{{0., 0., 1., 1.},
4036 {0., 0., 1., 1.},
4037 {2., 2., 3., 3.},
4038 {2., 2., 3., 3.}},
4039 {{0., 0., 1., 1.},
4040 {0., 0., 1., 1.},
4041 {2., 2., 3., 3.},
4042 {2., 2., 3., 3.}},
4043 {{4., 4., 5., 5.},
4044 {4., 4., 5., 5.},
4045 {6., 6., 7., 7.},
4046 {6., 6., 7., 7.}},
4047 {{4., 4., 5., 5.},
4048 {4., 4., 5., 5.},
4049 {6., 6., 7., 7.},
4050 {6., 6., 7., 7.}}}}},
4051 torch::kFloat);
4052 ASSERT_TRUE(output.allclose(expected));
4053 }
4054 {
4055 ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}));
4056 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4057 auto output = m(input);
4058 auto expected = torch::tensor(
4059 {{{{{0., 0., 1., 1., 1.},
4060 {0., 0., 1., 1., 1.},
4061 {2., 2., 3., 3., 3.},
4062 {2., 2., 3., 3., 3.},
4063 {2., 2., 3., 3., 3.}},
4064 {{0., 0., 1., 1., 1.},
4065 {0., 0., 1., 1., 1.},
4066 {2., 2., 3., 3., 3.},
4067 {2., 2., 3., 3., 3.},
4068 {2., 2., 3., 3., 3.}},
4069 {{4., 4., 5., 5., 5.},
4070 {4., 4., 5., 5., 5.},
4071 {6., 6., 7., 7., 7.},
4072 {6., 6., 7., 7., 7.},
4073 {6., 6., 7., 7., 7.}},
4074 {{4., 4., 5., 5., 5.},
4075 {4., 4., 5., 5., 5.},
4076 {6., 6., 7., 7., 7.},
4077 {6., 6., 7., 7., 7.},
4078 {6., 6., 7., 7., 7.}},
4079 {{4., 4., 5., 5., 5.},
4080 {4., 4., 5., 5., 5.},
4081 {6., 6., 7., 7., 7.},
4082 {6., 6., 7., 7., 7.},
4083 {6., 6., 7., 7., 7.}}}}},
4084 torch::kFloat);
4085 ASSERT_TRUE(output.allclose(expected));
4086 }
4087 }
4088
TEST_F(ModulesTest,ZeroPad1d)4089 TEST_F(ModulesTest, ZeroPad1d) {
4090 {
4091 ZeroPad1d m(ZeroPad1dOptions(2));
4092 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
4093 auto output = m(input);
4094 auto expected = torch::tensor(
4095 {{{0., 0., 0., 1., 2., 3., 0., 0.}, {0., 0., 4., 5., 6., 7., 0., 0.}}},
4096 torch::kFloat);
4097 ASSERT_TRUE(output.allclose(expected));
4098 }
4099 {
4100 ZeroPad1d m(ZeroPad1dOptions({3, 1}));
4101 auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
4102 auto output = m(input);
4103 auto expected = torch::tensor(
4104 {{{0., 0., 0., 0., 1., 2., 0.}, {0., 0., 0., 3., 4., 5., 0.}}},
4105 torch::kFloat);
4106 ASSERT_TRUE(output.allclose(expected));
4107 }
4108 }
4109
TEST_F(ModulesTest,ZeroPad2d)4110 TEST_F(ModulesTest, ZeroPad2d) {
4111 {
4112 ZeroPad2d m(ZeroPad2dOptions(2));
4113 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4114 auto output = m(input);
4115 auto expected = torch::tensor(
4116 {{{{0., 0., 0., 0., 0., 0., 0.},
4117 {0., 0., 0., 0., 0., 0., 0.},
4118 {0., 0., 0., 1., 2., 0., 0.},
4119 {0., 0., 3., 4., 5., 0., 0.},
4120 {0., 0., 6., 7., 8., 0., 0.},
4121 {0., 0., 0., 0., 0., 0., 0.},
4122 {0., 0., 0., 0., 0., 0., 0.}}}},
4123 torch::kFloat);
4124 ASSERT_TRUE(output.allclose(expected));
4125 }
4126 {
4127 ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0}));
4128 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4129 auto output = m(input);
4130 auto expected = torch::tensor(
4131 {{{{0., 0., 0., 0., 0.},
4132 {0., 0., 0., 0., 0.},
4133 {0., 0., 1., 2., 0.},
4134 {0., 3., 4., 5., 0.},
4135 {0., 6., 7., 8., 0.}}}},
4136 torch::kFloat);
4137 ASSERT_TRUE(output.allclose(expected));
4138 }
4139 }
4140
TEST_F(ModulesTest,ZeroPad3d)4141 TEST_F(ModulesTest, ZeroPad3d) {
4142 {
4143 ZeroPad3d m(ZeroPad3dOptions(1));
4144 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4145 auto output = m(input);
4146 auto expected = torch::tensor(
4147 {{{{{0., 0., 0., 0.},
4148 {0., 0., 0., 0.},
4149 {0., 0., 0., 0.},
4150 {0., 0., 0., 0.}},
4151 {{0., 0., 0., 0.},
4152 {0., 0., 1., 0.},
4153 {0., 2., 3., 0.},
4154 {0., 0., 0., 0.}},
4155 {{0., 0., 0., 0.},
4156 {0., 4., 5., 0.},
4157 {0., 6., 7., 0.},
4158 {0., 0., 0., 0.}},
4159 {{0., 0., 0., 0.},
4160 {0., 0., 0., 0.},
4161 {0., 0., 0., 0.},
4162 {0., 0., 0., 0.}}}}},
4163 torch::kFloat);
4164 ASSERT_TRUE(output.allclose(expected));
4165 }
4166 {
4167 ZeroPad3d m(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}));
4168 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4169 auto output = m(input);
4170 auto expected = torch::tensor(
4171 {{{{{0., 0., 0., 0., 0.},
4172 {0., 0., 0., 0., 0.},
4173 {0., 0., 0., 0., 0.},
4174 {0., 0., 0., 0., 0.},
4175 {0., 0., 0., 0., 0.}},
4176 {{0., 0., 0., 0., 0.},
4177 {0., 0., 1., 0., 0.},
4178 {0., 2., 3., 0., 0.},
4179 {0., 0., 0., 0., 0.},
4180 {0., 0., 0., 0., 0.}},
4181 {{0., 0., 0., 0., 0.},
4182 {0., 4., 5., 0., 0.},
4183 {0., 6., 7., 0., 0.},
4184 {0., 0., 0., 0., 0.},
4185 {0., 0., 0., 0., 0.}},
4186 {{0., 0., 0., 0., 0.},
4187 {0., 0., 0., 0., 0.},
4188 {0., 0., 0., 0., 0.},
4189 {0., 0., 0., 0., 0.},
4190 {0., 0., 0., 0., 0.}},
4191 {{0., 0., 0., 0., 0.},
4192 {0., 0., 0., 0., 0.},
4193 {0., 0., 0., 0., 0.},
4194 {0., 0., 0., 0., 0.},
4195 {0., 0., 0., 0., 0.}}}}},
4196 torch::kFloat);
4197 ASSERT_TRUE(output.allclose(expected));
4198 }
4199 }
4200
TEST_F(ModulesTest,ConstantPad1d)4201 TEST_F(ModulesTest, ConstantPad1d) {
4202 {
4203 ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
4204 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
4205 auto output = m(input);
4206 auto expected = torch::tensor(
4207 {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000},
4208 {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}},
4209 torch::kFloat);
4210 ASSERT_TRUE(output.allclose(expected));
4211 }
4212 {
4213 ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5));
4214 auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
4215 auto output = m(input);
4216 auto expected = torch::tensor(
4217 {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000},
4218 {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}},
4219 torch::kFloat);
4220 ASSERT_TRUE(output.allclose(expected));
4221 }
4222 }
4223
TEST_F(ModulesTest,ConstantPad2d)4224 TEST_F(ModulesTest, ConstantPad2d) {
4225 {
4226 ConstantPad2d m(ConstantPad2dOptions(2, 3.5));
4227 auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4228 auto output = m(input);
4229 auto expected = torch::tensor(
4230 {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4231 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4232 {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4233 {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4234 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4235 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4236 torch::kFloat);
4237 ASSERT_TRUE(output.allclose(expected));
4238 }
4239 {
4240 ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5));
4241 auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4242 auto output = m(input);
4243 auto expected = torch::tensor(
4244 {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4245 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4246 {3.5000, 3.5000, 3.5000, 0.0000, 1.0000},
4247 {3.5000, 3.5000, 3.5000, 2.0000, 3.0000},
4248 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4249 torch::kFloat);
4250 ASSERT_TRUE(output.allclose(expected));
4251 }
4252 }
4253
TEST_F(ModulesTest,ConstantPad3d)4254 TEST_F(ModulesTest, ConstantPad3d) {
4255 {
4256 ConstantPad3d m(ConstantPad3dOptions(1, 3.5));
4257 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4258 auto output = m(input);
4259 auto expected = torch::tensor(
4260 {{{{{3.5000, 3.5000, 3.5000, 3.5000},
4261 {3.5000, 3.5000, 3.5000, 3.5000},
4262 {3.5000, 3.5000, 3.5000, 3.5000},
4263 {3.5000, 3.5000, 3.5000, 3.5000}},
4264 {{3.5000, 3.5000, 3.5000, 3.5000},
4265 {3.5000, 0.0000, 1.0000, 3.5000},
4266 {3.5000, 2.0000, 3.0000, 3.5000},
4267 {3.5000, 3.5000, 3.5000, 3.5000}},
4268 {{3.5000, 3.5000, 3.5000, 3.5000},
4269 {3.5000, 4.0000, 5.0000, 3.5000},
4270 {3.5000, 6.0000, 7.0000, 3.5000},
4271 {3.5000, 3.5000, 3.5000, 3.5000}},
4272 {{3.5000, 3.5000, 3.5000, 3.5000},
4273 {3.5000, 3.5000, 3.5000, 3.5000},
4274 {3.5000, 3.5000, 3.5000, 3.5000},
4275 {3.5000, 3.5000, 3.5000, 3.5000}}}}},
4276 torch::kFloat);
4277 ASSERT_TRUE(output.allclose(expected));
4278 }
4279 {
4280 ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5));
4281 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4282 auto output = m(input);
4283 auto expected = torch::tensor(
4284 {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4285 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4286 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4287 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4288 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4289 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4290 {3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4291 {3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4292 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4293 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4294 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4295 {3.5000, 4.0000, 5.0000, 3.5000, 3.5000},
4296 {3.5000, 6.0000, 7.0000, 3.5000, 3.5000},
4297 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4298 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4299 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4300 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4301 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4302 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4303 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4304 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4305 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4306 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4307 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4308 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}},
4309 torch::kFloat);
4310 ASSERT_TRUE(output.allclose(expected));
4311 }
4312 }
4313
TEST_F(ModulesTest,CrossMapLRN2d)4314 TEST_F(ModulesTest, CrossMapLRN2d) {
4315 /// size 3, default options
4316 auto input =
4317 torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true);
4318 auto expected = torch::tensor(
4319 {{{{0.00000000, 0.99997497, 1.99980010},
4320 {2.99932500, 3.99840070, 4.99687700},
4321 {5.99460600, 6.99143740, 7.98722360}}}},
4322 torch::kFloat32);
4323 auto grad_expected = torch::tensor(
4324 {{{{1.00000000, 0.99992496, 0.99970007},
4325 {0.99932520, 0.99880093, 0.99812720},
4326 {0.99730474, 0.99633380, 0.99521490}}}},
4327 torch::kFloat32);
4328 auto crossmaplrn2d = CrossMapLRN2d(3);
4329 auto output = crossmaplrn2d(input);
4330 output.sum().backward();
4331
4332 ASSERT_TRUE(input.grad().allclose(grad_expected));
4333 ASSERT_TRUE(output.allclose(expected));
4334
4335 /// size change
4336 crossmaplrn2d =
4337 CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1));
4338 output = crossmaplrn2d(input);
4339 expected = torch::tensor(
4340 {{{{0.00000000, 0.99998120, 1.99985000},
4341 {2.99949400, 3.99880050, 4.99765800},
4342 {5.99595300, 6.99357600, 7.99041300}}}},
4343 torch::kFloat32);
4344 ASSERT_TRUE(output.allclose(expected));
4345
4346 /// alpha change
4347 crossmaplrn2d =
4348 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1));
4349 output = crossmaplrn2d(input);
4350 expected = torch::tensor(
4351 {{{{0.00000000, 0.99975010, 1.99800230},
4352 {2.99326750, 3.98407440, 4.96897600},
4353 {5.94656100, 6.91545720, 7.87434340}}}},
4354 torch::kFloat32);
4355 ASSERT_TRUE(output.allclose(expected));
4356
4357 /// beta change
4358 crossmaplrn2d =
4359 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1));
4360 output = crossmaplrn2d(input);
4361 expected = torch::tensor(
4362 {{{{0.00000000, 0.99996830, 1.99974680},
4363 {2.99914500, 3.99797440, 4.99604460},
4364 {5.99316840, 6.98915600, 7.98382000}}}},
4365 torch::kFloat32);
4366 ASSERT_TRUE(output.allclose(expected));
4367
4368 /// k change
4369 crossmaplrn2d =
4370 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2));
4371 output = crossmaplrn2d(input);
4372 expected = torch::tensor(
4373 {{{{0.00000000, 0.59459610, 1.18914770},
4374 {1.78361000, 2.37793870, 2.97208900},
4375 {3.56601700, 4.15967700, 4.75302650}}}},
4376 torch::kFloat32);
4377 ASSERT_TRUE(output.allclose(expected));
4378 }
4379
TEST_F(ModulesTest,RNNCell)4380 TEST_F(ModulesTest, RNNCell) {
4381 torch::manual_seed(0);
4382 auto rnn = RNNCell(1, 2);
4383
4384 auto input = torch::randn({3, 1});
4385 auto hx = torch::randn({3, 2});
4386 auto output = rnn(input, hx);
4387 auto expected =
4388 torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}});
4389 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4390
4391 output = rnn(input);
4392 expected =
4393 torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}});
4394 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4395
4396 input = torch::randn({1});
4397 hx = torch::randn({2});
4398 output = rnn(input, hx);
4399 expected = torch::tensor({0.2808, 0.6505});
4400 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4401
4402 {
4403 auto input = torch::randn({3, 2});
4404 auto hx = torch::randn({3, 2});
4405 ASSERT_THROWS_WITH(
4406 rnn(input, hx), "input has inconsistent input_size: got 2 expected 1");
4407 }
4408
4409 {
4410 auto input = torch::randn({3, 1});
4411 auto hx = torch::randn({3, 1});
4412 ASSERT_THROWS_WITH(
4413 rnn(input, hx),
4414 "hidden0 has inconsistent hidden_size: got 1, expected 2");
4415 }
4416
4417 {
4418 auto input = torch::randn({3, 1, 1, 1, 1});
4419 auto hx = torch::randn({3, 2});
4420 ASSERT_THROWS_WITH(
4421 rnn(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4422 }
4423
4424 {
4425 auto input = torch::randn({3, 1});
4426 auto hx = torch::randn({3, 1, 1, 1, 2});
4427 ASSERT_THROWS_WITH(
4428 rnn(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4429 }
4430 }
4431
TEST_F(ModulesTest,LSTMCell)4432 TEST_F(ModulesTest, LSTMCell) {
4433 torch::manual_seed(0);
4434 auto lstm = LSTMCell(1, 2);
4435
4436 auto input = torch::randn({3, 1});
4437 auto hx = torch::randn({3, 2});
4438 auto cx = torch::randn({3, 2});
4439 auto output = lstm(input, std::make_tuple(hx, cx));
4440 auto output_hx = std::get<0>(output);
4441 auto output_cx = std::get<1>(output);
4442 auto expected_hx =
4443 torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}});
4444 auto expected_cx =
4445 torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}});
4446 ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4447 ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4448
4449 output = lstm(input);
4450 output_hx = std::get<0>(output);
4451 output_cx = std::get<1>(output);
4452 expected_hx =
4453 torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}});
4454 expected_cx =
4455 torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}});
4456 ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4457 ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4458
4459 input = torch::randn({1});
4460 hx = torch::randn({2});
4461 cx = torch::randn({2});
4462 output = lstm(input, std::make_tuple(hx, cx));
4463 output_hx = std::get<0>(output);
4464 output_cx = std::get<1>(output);
4465 expected_hx = torch::tensor({-0.0443, 0.1537});
4466 expected_cx = torch::tensor({-0.1195, 0.2144});
4467 ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4468 ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4469
4470 {
4471 auto input = torch::randn({3, 2});
4472 auto hx = torch::randn({3, 2});
4473 auto cx = torch::randn({3, 2});
4474 ASSERT_THROWS_WITH(
4475 lstm(input, std::make_tuple(hx, cx)),
4476 "input has inconsistent input_size: got 2 expected 1");
4477 }
4478
4479 {
4480 auto input = torch::randn({3, 1});
4481 auto hx = torch::randn({3, 1});
4482 auto cx = torch::randn({3, 2});
4483 ASSERT_THROWS_WITH(
4484 lstm(input, std::make_tuple(hx, cx)),
4485 "hidden0 has inconsistent hidden_size: got 1, expected 2");
4486 }
4487
4488 {
4489 auto input = torch::randn({3, 1});
4490 auto hx = torch::randn({3, 2});
4491 auto cx = torch::randn({3, 1});
4492 ASSERT_THROWS_WITH(
4493 lstm(input, std::make_tuple(hx, cx)),
4494 "hidden1 has inconsistent hidden_size: got 1, expected 2");
4495 }
4496
4497 {
4498 auto input = torch::randn({3, 1, 1, 1, 1});
4499 auto hx = torch::randn({3, 1});
4500 auto cx = torch::randn({3, 1});
4501 ASSERT_THROWS_WITH(
4502 lstm(input, std::make_tuple(hx, cx)),
4503 "Expected input to be 1D or 2D, got 5D instead");
4504 }
4505
4506 {
4507 auto input = torch::randn({3, 1});
4508 auto hx = torch::randn({3, 1, 1, 1, 2});
4509 auto cx = torch::randn({3, 2});
4510 ASSERT_THROWS_WITH(
4511 lstm(input, std::make_tuple(hx, cx)),
4512 "Expected hx[0] to be 1D or 2D, got 5D instead");
4513 }
4514
4515 {
4516 auto input = torch::randn({3, 1});
4517 auto hx = torch::randn({3, 2});
4518 auto cx = torch::randn({3, 1, 1, 1, 2});
4519 ASSERT_THROWS_WITH(
4520 lstm(input, std::make_tuple(hx, cx)),
4521 "Expected hx[1] to be 1D or 2D, got 5D instead");
4522 }
4523 }
4524
TEST_F(ModulesTest,GRUCell)4525 TEST_F(ModulesTest, GRUCell) {
4526 torch::manual_seed(0);
4527 auto gru = GRUCell(1, 2);
4528
4529 auto input = torch::randn({3, 1});
4530 auto hx = torch::randn({3, 2});
4531 auto output = gru(input, hx);
4532 auto expected =
4533 torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}});
4534 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4535
4536 output = gru(input);
4537 expected =
4538 torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}});
4539 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4540
4541 input = torch::randn({1});
4542 hx = torch::randn({2});
4543 output = gru(input, hx);
4544 expected = torch::tensor({-1.0058, -0.3025});
4545 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4546
4547 {
4548 auto input = torch::randn({3, 2});
4549 auto hx = torch::randn({3, 2});
4550 ASSERT_THROWS_WITH(
4551 gru(input, hx), "input has inconsistent input_size: got 2 expected 1");
4552 }
4553
4554 {
4555 auto input = torch::randn({3, 1});
4556 auto hx = torch::randn({3, 1});
4557 ASSERT_THROWS_WITH(
4558 gru(input, hx),
4559 "hidden0 has inconsistent hidden_size: got 1, expected 2");
4560 }
4561
4562 {
4563 auto input = torch::randn({3, 1, 1, 1, 1});
4564 auto hx = torch::randn({3, 2});
4565 ASSERT_THROWS_WITH(
4566 gru(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4567 }
4568
4569 {
4570 auto input = torch::randn({3, 1});
4571 auto hx = torch::randn({3, 1, 1, 1, 2});
4572 ASSERT_THROWS_WITH(
4573 gru(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4574 }
4575 }
4576
TEST_F(ModulesTest,PrettyPrintLinear)4577 TEST_F(ModulesTest, PrettyPrintLinear) {
4578 ASSERT_EQ(
4579 c10::str(Linear(3, 4)),
4580 "torch::nn::Linear(in_features=3, out_features=4, bias=true)");
4581 }
4582
TEST_F(ModulesTest,PrettyPrintBilinear)4583 TEST_F(ModulesTest, PrettyPrintBilinear) {
4584 ASSERT_EQ(
4585 c10::str(Bilinear(3, 2, 4)),
4586 "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)");
4587 ASSERT_EQ(
4588 c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))),
4589 "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)");
4590 }
4591
TEST_F(ModulesTest,PrettyPrintConv)4592 TEST_F(ModulesTest, PrettyPrintConv) {
4593 ASSERT_EQ(
4594 c10::str(Conv1d(3, 4, 5)),
4595 "torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
4596
4597 ASSERT_EQ(
4598 c10::str(Conv2d(3, 4, 5)),
4599 "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4600 ASSERT_EQ(
4601 c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
4602 "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4603 {
4604 const auto options =
4605 Conv2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4606 ASSERT_EQ(
4607 c10::str(Conv2d(options)),
4608 "torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4609 }
4610
4611 ASSERT_EQ(
4612 c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4613 "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4614 {
4615 const auto options = Conv3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4616 .stride({1, 2, 3})
4617 .padding(1)
4618 .dilation(0)
4619 .groups(2)
4620 .bias(false)
4621 .padding_mode(torch::kCircular);
4622 ASSERT_EQ(
4623 c10::str(Conv3d(options)),
4624 "torch::nn::Conv3d("
4625 "4, "
4626 "4, "
4627 "kernel_size=[5, 6, 7], "
4628 "stride=[1, 2, 3], "
4629 "padding=[1, 1, 1], "
4630 "dilation=[0, 0, 0], "
4631 "groups=2, "
4632 "bias=false, "
4633 "padding_mode=kCircular)");
4634 }
4635 }
4636
TEST_F(ModulesTest,PrettyPrintConvTranspose)4637 TEST_F(ModulesTest, PrettyPrintConvTranspose) {
4638 ASSERT_EQ(
4639 c10::str(ConvTranspose1d(3, 4, 5)),
4640 "torch::nn::ConvTranspose1d(3, 4, kernel_size=5, stride=1)");
4641
4642 ASSERT_EQ(
4643 c10::str(ConvTranspose2d(3, 4, 5)),
4644 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4645 ASSERT_EQ(
4646 c10::str(ConvTranspose2d(ConvTranspose2dOptions(3, 4, 5).stride(2))),
4647 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4648 {
4649 const auto options =
4650 ConvTranspose2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4651 ASSERT_EQ(
4652 c10::str(ConvTranspose2d(options)),
4653 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4654 }
4655
4656 ASSERT_EQ(
4657 c10::str(ConvTranspose3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4658 "torch::nn::ConvTranspose3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4659 {
4660 const auto options =
4661 ConvTranspose3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4662 .stride({1, 2, 3})
4663 .padding(1)
4664 .dilation(0)
4665 .groups(2)
4666 .bias(false)
4667 .padding_mode(torch::kCircular);
4668 ASSERT_EQ(
4669 c10::str(ConvTranspose3d(options)),
4670 "torch::nn::ConvTranspose3d("
4671 "4, "
4672 "4, "
4673 "kernel_size=[5, 6, 7], "
4674 "stride=[1, 2, 3], "
4675 "padding=[1, 1, 1], "
4676 "dilation=[0, 0, 0], "
4677 "groups=2, "
4678 "bias=false, "
4679 "padding_mode=kCircular)");
4680 }
4681 }
4682
TEST_F(ModulesTest,PrettyPrintUpsample)4683 TEST_F(ModulesTest, PrettyPrintUpsample) {
4684 ASSERT_EQ(
4685 c10::str(
4686 Upsample(UpsampleOptions().size(std::vector<int64_t>({2, 4, 4})))),
4687 "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)");
4688 ASSERT_EQ(
4689 c10::str(Upsample(UpsampleOptions()
4690 .scale_factor(std::vector<double>({0.5, 1.5}))
4691 .mode(torch::kBilinear))),
4692 "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)");
4693 }
4694
TEST_F(ModulesTest,PrettyPrintFold)4695 TEST_F(ModulesTest, PrettyPrintFold) {
4696 ASSERT_EQ(
4697 c10::str(Fold(FoldOptions({2, 2}, {5, 5}))),
4698 "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4699 ASSERT_EQ(
4700 c10::str(Fold(
4701 FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))),
4702 "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4703 }
4704
TEST_F(ModulesTest,PrettyPrintUnfold)4705 TEST_F(ModulesTest, PrettyPrintUnfold) {
4706 ASSERT_EQ(
4707 c10::str(Unfold(torch::IntArrayRef({2, 4}))),
4708 "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4709 ASSERT_EQ(
4710 c10::str(
4711 Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))),
4712 "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4713 }
4714
TEST_F(ModulesTest,PrettyPrintMaxPool)4715 TEST_F(ModulesTest, PrettyPrintMaxPool) {
4716 ASSERT_EQ(
4717 c10::str(MaxPool1d(5)),
4718 "torch::nn::MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=false)");
4719 ASSERT_EQ(
4720 c10::str(MaxPool2d(5)),
4721 "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4722 ASSERT_EQ(
4723 c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))),
4724 "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4725 ASSERT_EQ(
4726 c10::str(MaxPool3d(5)),
4727 "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4728 ASSERT_EQ(
4729 c10::str(MaxPool3d(MaxPool3dOptions(5).stride(2))),
4730 "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4731
4732 const auto options =
4733 MaxPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4734 ASSERT_EQ(
4735 c10::str(MaxPool2d(options)),
4736 "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4737 }
4738
TEST_F(ModulesTest,PrettyPrintAvgPool)4739 TEST_F(ModulesTest, PrettyPrintAvgPool) {
4740 ASSERT_EQ(
4741 c10::str(AvgPool1d(5)),
4742 "torch::nn::AvgPool1d(kernel_size=5, stride=5, padding=0)");
4743 ASSERT_EQ(
4744 c10::str(AvgPool2d(5)),
4745 "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4746 ASSERT_EQ(
4747 c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))),
4748 "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0])");
4749 ASSERT_EQ(
4750 c10::str(AvgPool3d(5)),
4751 "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0])");
4752 ASSERT_EQ(
4753 c10::str(AvgPool3d(AvgPool3dOptions(5).stride(2))),
4754 "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0])");
4755
4756 const auto options =
4757 AvgPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4758 ASSERT_EQ(
4759 c10::str(AvgPool2d(options)),
4760 "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0])");
4761 }
4762
TEST_F(ModulesTest,PrettyPrinFractionalMaxPool)4763 TEST_F(ModulesTest, PrettyPrinFractionalMaxPool) {
4764 ASSERT_EQ(
4765 c10::str(
4766 FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))),
4767 "torch::nn::FractionalMaxPool2d()");
4768 ASSERT_EQ(
4769 c10::str(
4770 FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))),
4771 "torch::nn::FractionalMaxPool3d()");
4772 }
4773
TEST_F(ModulesTest,PrettyPrintLPPool)4774 TEST_F(ModulesTest, PrettyPrintLPPool) {
4775 ASSERT_EQ(
4776 c10::str(LPPool1d(2, 5)),
4777 "torch::nn::LPPool1d(norm_type=2, kernel_size=5, stride=5, ceil_mode=false)");
4778 ASSERT_EQ(
4779 c10::str(LPPool1d(LPPool1dOptions(1, 2).stride(5).ceil_mode(true))),
4780 "torch::nn::LPPool1d(norm_type=1, kernel_size=2, stride=5, ceil_mode=true)");
4781 ASSERT_EQ(
4782 c10::str(LPPool2d(2, std::vector<int64_t>({1, 2}))),
4783 "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)");
4784 ASSERT_EQ(
4785 c10::str(LPPool2d(LPPool2dOptions(1, std::vector<int64_t>({3, 4}))
4786 .stride({5, 6})
4787 .ceil_mode(true))),
4788 "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
4789 ASSERT_EQ(
4790 c10::str(LPPool3d(2, std::vector<int64_t>({1, 2, 3}))),
4791 "torch::nn::LPPool3d(norm_type=2, kernel_size=[1, 2, 3], stride=[1, 2, 3], ceil_mode=false)");
4792 ASSERT_EQ(
4793 c10::str(LPPool3d(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5}))
4794 .stride({5, 6, 7})
4795 .ceil_mode(true))),
4796 "torch::nn::LPPool3d(norm_type=1, kernel_size=[3, 4, 5], stride=[5, 6, 7], ceil_mode=true)");
4797 }
4798
TEST_F(ModulesTest,PrettyPrintAdaptiveMaxPool)4799 TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
4800 ASSERT_EQ(
4801 c10::str(AdaptiveMaxPool1d(5)),
4802 "torch::nn::AdaptiveMaxPool1d(output_size=5)");
4803
4804 const auto options = AdaptiveMaxPool1dOptions(3);
4805 ASSERT_EQ(
4806 c10::str(AdaptiveMaxPool1d(options)),
4807 "torch::nn::AdaptiveMaxPool1d(output_size=3)");
4808
4809 ASSERT_EQ(
4810 c10::str(AdaptiveMaxPool2d(5)),
4811 "torch::nn::AdaptiveMaxPool2d(output_size=[5, 5])");
4812 ASSERT_EQ(
4813 c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, 6}))),
4814 "torch::nn::AdaptiveMaxPool2d(output_size=[5, 6])");
4815 ASSERT_EQ(
4816 c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, std::nullopt}))),
4817 "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])");
4818 ASSERT_EQ(
4819 c10::str(AdaptiveMaxPool2d(
4820 AdaptiveMaxPool2dOptions({std::nullopt, std::nullopt}))),
4821 "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])");
4822
4823 ASSERT_EQ(
4824 c10::str(AdaptiveMaxPool3d(5)),
4825 "torch::nn::AdaptiveMaxPool3d(output_size=[5, 5, 5])");
4826 ASSERT_EQ(
4827 c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))),
4828 "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])");
4829 ASSERT_EQ(
4830 c10::str(
4831 AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, std::nullopt, 7}))),
4832 "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])");
4833 ASSERT_EQ(
4834 c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions(
4835 {std::nullopt, std::nullopt, std::nullopt}))),
4836 "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])");
4837 }
4838
TEST_F(ModulesTest,PrettyPrintAdaptiveAvgPool)4839 TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
4840 ASSERT_EQ(
4841 c10::str(AdaptiveAvgPool1d(5)),
4842 "torch::nn::AdaptiveAvgPool1d(output_size=5)");
4843
4844 ASSERT_EQ(
4845 c10::str(AdaptiveAvgPool2d(5)),
4846 "torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])");
4847 ASSERT_EQ(
4848 c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, 6}))),
4849 "torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])");
4850 ASSERT_EQ(
4851 c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, std::nullopt}))),
4852 "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])");
4853 ASSERT_EQ(
4854 c10::str(AdaptiveAvgPool2d(
4855 AdaptiveAvgPool2dOptions({std::nullopt, std::nullopt}))),
4856 "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])");
4857
4858 ASSERT_EQ(
4859 c10::str(AdaptiveAvgPool3d(5)),
4860 "torch::nn::AdaptiveAvgPool3d(output_size=[5, 5, 5])");
4861 ASSERT_EQ(
4862 c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))),
4863 "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])");
4864 ASSERT_EQ(
4865 c10::str(
4866 AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, std::nullopt, 7}))),
4867 "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])");
4868 ASSERT_EQ(
4869 c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions(
4870 {std::nullopt, std::nullopt, std::nullopt}))),
4871 "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])");
4872 }
4873
TEST_F(ModulesTest,PrettyPrintMaxUnpool)4874 TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
4875 ASSERT_EQ(
4876 c10::str(MaxUnpool1d(5)),
4877 "torch::nn::MaxUnpool1d(kernel_size=5, stride=5, padding=0)");
4878 ASSERT_EQ(
4879 c10::str(MaxUnpool1d(MaxUnpool1dOptions(5).stride(3).padding(1))),
4880 "torch::nn::MaxUnpool1d(kernel_size=5, stride=3, padding=1)");
4881
4882 ASSERT_EQ(
4883 c10::str(MaxUnpool2d(5)),
4884 "torch::nn::MaxUnpool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4885 ASSERT_EQ(
4886 c10::str(MaxUnpool2d(std::vector<int64_t>{5, 6})),
4887 "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])");
4888 ASSERT_EQ(
4889 c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector<int64_t>{5, 6})
4890 .stride({3, 4})
4891 .padding({1, 2}))),
4892 "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])");
4893 }
4894
TEST_F(ModulesTest,PrettyPrintDropout)4895 TEST_F(ModulesTest, PrettyPrintDropout) {
4896 ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)");
4897 ASSERT_EQ(
4898 c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)");
4899 ASSERT_EQ(
4900 c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))),
4901 "torch::nn::Dropout(p=0.42, inplace=true)");
4902 }
4903
TEST_F(ModulesTest,PrettyPrintDropout2d)4904 TEST_F(ModulesTest, PrettyPrintDropout2d) {
4905 ASSERT_EQ(
4906 c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)");
4907 ASSERT_EQ(
4908 c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)");
4909 ASSERT_EQ(
4910 c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))),
4911 "torch::nn::Dropout2d(p=0.42, inplace=true)");
4912 }
4913
TEST_F(ModulesTest,PrettyPrintDropout3d)4914 TEST_F(ModulesTest, PrettyPrintDropout3d) {
4915 ASSERT_EQ(
4916 c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)");
4917 ASSERT_EQ(
4918 c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)");
4919 ASSERT_EQ(
4920 c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))),
4921 "torch::nn::Dropout3d(p=0.42, inplace=true)");
4922 }
4923
TEST_F(ModulesTest,PrettyPrintFunctional)4924 TEST_F(ModulesTest, PrettyPrintFunctional) {
4925 ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
4926 }
4927
TEST_F(ModulesTest,PrettyPrintBatchNorm1d)4928 TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
4929 ASSERT_EQ(
4930 c10::str(BatchNorm1d(BatchNorm1dOptions(4)
4931 .eps(0.5)
4932 .momentum(0.1)
4933 .affine(false)
4934 .track_running_stats(true))),
4935 "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4936 }
4937
TEST_F(ModulesTest,PrettyPrintBatchNorm2d)4938 TEST_F(ModulesTest, PrettyPrintBatchNorm2d) {
4939 ASSERT_EQ(
4940 c10::str(BatchNorm2d(BatchNorm2dOptions(4)
4941 .eps(0.5)
4942 .momentum(0.1)
4943 .affine(false)
4944 .track_running_stats(true))),
4945 "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4946 }
4947
TEST_F(ModulesTest,PrettyPrintBatchNorm3d)4948 TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
4949 ASSERT_EQ(
4950 c10::str(BatchNorm3d(BatchNorm3dOptions(4)
4951 .eps(0.5)
4952 .momentum(0.1)
4953 .affine(false)
4954 .track_running_stats(true))),
4955 "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4956 }
4957
TEST_F(ModulesTest,PrettyPrintInstanceNorm1d)4958 TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
4959 ASSERT_EQ(
4960 c10::str(InstanceNorm1d(InstanceNorm1dOptions(4)
4961 .eps(0.5)
4962 .momentum(0.1)
4963 .affine(false)
4964 .track_running_stats(true))),
4965 "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4966 }
4967
TEST_F(ModulesTest,PrettyPrintInstanceNorm2d)4968 TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
4969 ASSERT_EQ(
4970 c10::str(InstanceNorm2d(InstanceNorm2dOptions(4)
4971 .eps(0.5)
4972 .momentum(0.1)
4973 .affine(false)
4974 .track_running_stats(true))),
4975 "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4976 }
4977
TEST_F(ModulesTest,PrettyPrintInstanceNorm3d)4978 TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
4979 ASSERT_EQ(
4980 c10::str(InstanceNorm3d(InstanceNorm3dOptions(4)
4981 .eps(0.5)
4982 .momentum(0.1)
4983 .affine(false)
4984 .track_running_stats(true))),
4985 "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4986 }
4987
TEST_F(ModulesTest,PrettyPrintLayerNorm)4988 TEST_F(ModulesTest, PrettyPrintLayerNorm) {
4989 ASSERT_EQ(
4990 c10::str(LayerNorm(LayerNormOptions({2, 2}))),
4991 "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)");
4992 ASSERT_EQ(
4993 c10::str(LayerNorm(
4994 LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))),
4995 "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)");
4996 }
4997
TEST_F(ModulesTest,PrettyPrintGroupNorm)4998 TEST_F(ModulesTest, PrettyPrintGroupNorm) {
4999 ASSERT_EQ(
5000 c10::str(GroupNorm(GroupNormOptions(2, 2))),
5001 "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)");
5002 ASSERT_EQ(
5003 c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))),
5004 "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)");
5005 }
5006
TEST_F(ModulesTest,PrettyPrintLocalResponseNorm)5007 TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
5008 ASSERT_EQ(
5009 c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
5010 "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)");
5011 ASSERT_EQ(
5012 c10::str(LocalResponseNorm(
5013 LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
5014 "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)");
5015 }
5016
TEST_F(ModulesTest,PrettyPrintEmbedding)5017 TEST_F(ModulesTest, PrettyPrintEmbedding) {
5018 ASSERT_EQ(
5019 c10::str(Embedding(EmbeddingOptions(10, 2))),
5020 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");
5021 ASSERT_EQ(
5022 c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))),
5023 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");
5024 ASSERT_EQ(
5025 c10::str(Embedding(EmbeddingOptions(10, 2)
5026 .padding_idx(3)
5027 .max_norm(2)
5028 .norm_type(2.5)
5029 .scale_grad_by_freq(true)
5030 .sparse(true))),
5031 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
5032 }
5033
TEST_F(ModulesTest,PrettyPrintEmbeddingBag)5034 TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
5035 ASSERT_EQ(
5036 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))),
5037 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)");
5038 ASSERT_EQ(
5039 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))),
5040 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)");
5041 ASSERT_EQ(
5042 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5043 .max_norm(2)
5044 .norm_type(2.5)
5045 .scale_grad_by_freq(true)
5046 .sparse(true))),
5047 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
5048 ASSERT_EQ(
5049 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5050 .max_norm(2)
5051 .norm_type(2.5)
5052 .scale_grad_by_freq(true)
5053 .sparse(true)
5054 .mode(torch::kSum))),
5055 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)");
5056 ASSERT_EQ(
5057 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
5058 .max_norm(2)
5059 .norm_type(2.5)
5060 .scale_grad_by_freq(true)
5061 .sparse(true)
5062 .mode(torch::kSum)
5063 .padding_idx(5))),
5064 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)");
5065 }
5066
TEST_F(ModulesTest,PrettyPrintL1Loss)5067 TEST_F(ModulesTest, PrettyPrintL1Loss) {
5068 ASSERT_EQ(c10::str(L1Loss()), "torch::nn::L1Loss()");
5069 }
TEST_F(ModulesTest,PrettyPrintKLDivLoss)5070 TEST_F(ModulesTest, PrettyPrintKLDivLoss) {
5071 ASSERT_EQ(c10::str(KLDivLoss()), "torch::nn::KLDivLoss()");
5072 }
TEST_F(ModulesTest,PrettyPrintMSELoss)5073 TEST_F(ModulesTest, PrettyPrintMSELoss) {
5074 ASSERT_EQ(c10::str(MSELoss()), "torch::nn::MSELoss()");
5075 }
TEST_F(ModulesTest,PrettyPrintBCELoss)5076 TEST_F(ModulesTest, PrettyPrintBCELoss) {
5077 ASSERT_EQ(c10::str(BCELoss()), "torch::nn::BCELoss()");
5078 }
TEST_F(ModulesTest,PrettyPrintHingeEmbeddingLoss)5079 TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) {
5080 ASSERT_EQ(
5081 c10::str(HingeEmbeddingLoss(HingeEmbeddingLossOptions().margin(4))),
5082 "torch::nn::HingeEmbeddingLoss(margin=4)");
5083 }
5084
TEST_F(ModulesTest,PrettyPrintCosineEmbeddingLoss)5085 TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) {
5086 ASSERT_EQ(
5087 c10::str(CosineEmbeddingLoss(CosineEmbeddingLossOptions().margin(0.25))),
5088 "torch::nn::CosineEmbeddingLoss(margin=0.25)");
5089 }
5090
TEST_F(ModulesTest,PrettyPrintTripletMarginLoss)5091 TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
5092 ASSERT_EQ(
5093 c10::str(TripletMarginLoss(
5094 TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))),
5095 "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
5096 }
5097
TEST_F(ModulesTest,PrettyPrintTripletMarginWithDistanceLoss)5098 TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
5099 auto distanceOptions = TripletMarginWithDistanceLossOptions()
5100 .distance_function([&](const torch::Tensor& x,
5101 const torch::Tensor& y) {
5102 return torch::pairwise_distance(x, y, 2.0, 1e-6);
5103 })
5104 .margin(1.5)
5105 .swap(true)
5106 .reduction(torch::kMean);
5107 ASSERT_EQ(
5108 c10::str(TripletMarginWithDistanceLoss(distanceOptions)),
5109 "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
5110 }
5111
TEST_F(ModulesTest,PrettyPrintNLLLoss)5112 TEST_F(ModulesTest, PrettyPrintNLLLoss) {
5113 ASSERT_EQ(c10::str(NLLLoss()), "torch::nn::NLLLoss()");
5114 }
5115
TEST_F(ModulesTest,PrettyPrinCrossEntropyLoss)5116 TEST_F(ModulesTest, PrettyPrinCrossEntropyLoss) {
5117 ASSERT_EQ(c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()");
5118 }
5119
TEST_F(ModulesTest,PrettyPrintMultiLabelMarginLoss)5120 TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) {
5121 ASSERT_EQ(
5122 c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()");
5123 }
5124
TEST_F(ModulesTest,PrettyPrintMultiLabelSoftMarginLoss)5125 TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
5126 ASSERT_EQ(
5127 c10::str(MultiLabelSoftMarginLoss()),
5128 "torch::nn::MultiLabelSoftMarginLoss()");
5129 }
5130
TEST_F(ModulesTest,PrettyPrintSoftMarginLoss)5131 TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) {
5132 ASSERT_EQ(c10::str(SoftMarginLoss()), "torch::nn::SoftMarginLoss()");
5133 }
5134
TEST_F(ModulesTest,PrettyPrintCosineSimilarity)5135 TEST_F(ModulesTest, PrettyPrintCosineSimilarity) {
5136 ASSERT_EQ(
5137 c10::str(CosineSimilarity()),
5138 "torch::nn::CosineSimilarity(dim=1, eps=1e-08)");
5139 ASSERT_EQ(
5140 c10::str(CosineSimilarity(CosineSimilarityOptions().dim(0).eps(0.5))),
5141 "torch::nn::CosineSimilarity(dim=0, eps=0.5)");
5142 }
5143
TEST_F(ModulesTest,PrettyPrintPairwiseDistance)5144 TEST_F(ModulesTest, PrettyPrintPairwiseDistance) {
5145 ASSERT_EQ(
5146 c10::str(PairwiseDistance()),
5147 "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)");
5148 ASSERT_EQ(
5149 c10::str(PairwiseDistance(
5150 PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))),
5151 "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)");
5152 }
5153
TEST_F(ModulesTest,PrettyPrintReflectionPad)5154 TEST_F(ModulesTest, PrettyPrintReflectionPad) {
5155 ASSERT_EQ(
5156 c10::str(ReflectionPad1d(ReflectionPad1dOptions(2))),
5157 "torch::nn::ReflectionPad1d(padding=[2, 2])");
5158 ASSERT_EQ(
5159 c10::str(ReflectionPad1d(ReflectionPad1dOptions({3, 1}))),
5160 "torch::nn::ReflectionPad1d(padding=[3, 1])");
5161 ASSERT_EQ(
5162 c10::str(ReflectionPad2d(ReflectionPad2dOptions(2))),
5163 "torch::nn::ReflectionPad2d(padding=[2, 2, 2, 2])");
5164 ASSERT_EQ(
5165 c10::str(ReflectionPad2d(ReflectionPad2dOptions({1, 1, 2, 0}))),
5166 "torch::nn::ReflectionPad2d(padding=[1, 1, 2, 0])");
5167 }
5168
TEST_F(ModulesTest,PrettyPrintReplicationPad)5169 TEST_F(ModulesTest, PrettyPrintReplicationPad) {
5170 ASSERT_EQ(
5171 c10::str(ReplicationPad1d(ReplicationPad1dOptions(2))),
5172 "torch::nn::ReplicationPad1d(padding=[2, 2])");
5173 ASSERT_EQ(
5174 c10::str(ReplicationPad1d(ReplicationPad1dOptions({3, 1}))),
5175 "torch::nn::ReplicationPad1d(padding=[3, 1])");
5176 ASSERT_EQ(
5177 c10::str(ReplicationPad2d(ReplicationPad2dOptions(2))),
5178 "torch::nn::ReplicationPad2d(padding=[2, 2, 2, 2])");
5179 ASSERT_EQ(
5180 c10::str(ReplicationPad2d(ReplicationPad2dOptions({1, 1, 2, 0}))),
5181 "torch::nn::ReplicationPad2d(padding=[1, 1, 2, 0])");
5182 ASSERT_EQ(
5183 c10::str(ReplicationPad3d(ReplicationPad3dOptions(1))),
5184 "torch::nn::ReplicationPad3d(padding=[1, 1, 1, 1, 1, 1])");
5185 ASSERT_EQ(
5186 c10::str(ReplicationPad3d(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}))),
5187 "torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
5188 }
5189
TEST_F(ModulesTest,PrettyPrintZeroPad)5190 TEST_F(ModulesTest, PrettyPrintZeroPad) {
5191 ASSERT_EQ(
5192 c10::str(ZeroPad1d(ZeroPad1dOptions(2))),
5193 "torch::nn::ZeroPad1d(padding=[2, 2])");
5194 ASSERT_EQ(
5195 c10::str(ZeroPad1d(ZeroPad1dOptions({3, 1}))),
5196 "torch::nn::ZeroPad1d(padding=[3, 1])");
5197 ASSERT_EQ(
5198 c10::str(ZeroPad2d(ZeroPad2dOptions(2))),
5199 "torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
5200 ASSERT_EQ(
5201 c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))),
5202 "torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");
5203 ASSERT_EQ(
5204 c10::str(ZeroPad3d(ZeroPad3dOptions(1))),
5205 "torch::nn::ZeroPad3d(padding=[1, 1, 1, 1, 1, 1])");
5206 ASSERT_EQ(
5207 c10::str(ZeroPad3d(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}))),
5208 "torch::nn::ZeroPad3d(padding=[1, 2, 1, 2, 1, 2])");
5209 }
5210
TEST_F(ModulesTest,PrettyPrintConstantPad)5211 TEST_F(ModulesTest, PrettyPrintConstantPad) {
5212 ASSERT_EQ(
5213 c10::str(ConstantPad1d(ConstantPad1dOptions(2, 3.5))),
5214 "torch::nn::ConstantPad1d(padding=[2, 2], value=3.5)");
5215 ASSERT_EQ(
5216 c10::str(ConstantPad1d(ConstantPad1dOptions({3, 1}, 3.5))),
5217 "torch::nn::ConstantPad1d(padding=[3, 1], value=3.5)");
5218 ASSERT_EQ(
5219 c10::str(ConstantPad2d(ConstantPad2dOptions(2, 3.5))),
5220 "torch::nn::ConstantPad2d(padding=[2, 2, 2, 2], value=3.5)");
5221 ASSERT_EQ(
5222 c10::str(ConstantPad2d(ConstantPad2dOptions({3, 0, 2, 1}, 3.5))),
5223 "torch::nn::ConstantPad2d(padding=[3, 0, 2, 1], value=3.5)");
5224 ASSERT_EQ(
5225 c10::str(ConstantPad3d(ConstantPad3dOptions(1, 3.5))),
5226 "torch::nn::ConstantPad3d(padding=[1, 1, 1, 1, 1, 1], value=3.5)");
5227 ASSERT_EQ(
5228 c10::str(ConstantPad3d(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5))),
5229 "torch::nn::ConstantPad3d(padding=[1, 2, 1, 2, 1, 2], value=3.5)");
5230 }
5231
TEST_F(ModulesTest,PrettyPrintNestedModel)5232 TEST_F(ModulesTest, PrettyPrintNestedModel) {
5233 struct InnerTestModule : torch::nn::Module {
5234 InnerTestModule()
5235 : torch::nn::Module("InnerTestModule"),
5236 fc(register_module("fc", torch::nn::Linear(3, 4))),
5237 table(register_module("table", torch::nn::Embedding(10, 2))) {}
5238
5239 torch::nn::Linear fc;
5240 torch::nn::Embedding table;
5241 };
5242
5243 struct TestModule : torch::nn::Module {
5244 TestModule()
5245 : torch::nn::Module("TestModule"),
5246 fc(register_module("fc", torch::nn::Linear(4, 5))),
5247 table(register_module(
5248 "table",
5249 torch::nn::Embedding(EmbeddingOptions(10, 2)))),
5250 inner(register_module("inner", std::make_shared<InnerTestModule>())) {
5251 }
5252
5253 torch::nn::Linear fc;
5254 torch::nn::Embedding table;
5255 std::shared_ptr<InnerTestModule> inner;
5256 };
5257
5258 ASSERT_EQ(
5259 c10::str(TestModule{}),
5260 "TestModule(\n"
5261 " (fc): torch::nn::Linear(in_features=4, out_features=5, bias=true)\n"
5262 " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5263 " (inner): InnerTestModule(\n"
5264 " (fc): torch::nn::Linear(in_features=3, out_features=4, bias=true)\n"
5265 " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5266 " )\n"
5267 ")");
5268 }
5269
TEST_F(ModulesTest,PrettyPrintELU)5270 TEST_F(ModulesTest, PrettyPrintELU) {
5271 ASSERT_EQ(c10::str(ELU()), "torch::nn::ELU(alpha=1)");
5272 ASSERT_EQ(
5273 c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))),
5274 "torch::nn::ELU(alpha=42.42, inplace=true)");
5275 }
5276
TEST_F(ModulesTest,PrettyPrintSELU)5277 TEST_F(ModulesTest, PrettyPrintSELU) {
5278 ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()");
5279 ASSERT_EQ(
5280 c10::str(SELU(SELUOptions().inplace(true))),
5281 "torch::nn::SELU(inplace=true)");
5282 }
5283
TEST_F(ModulesTest,PrettyPrintGLU)5284 TEST_F(ModulesTest, PrettyPrintGLU) {
5285 ASSERT_EQ(c10::str(GLU()), "torch::nn::GLU(dim=-1)");
5286 ASSERT_EQ(c10::str(GLU(1)), "torch::nn::GLU(dim=1)");
5287 }
5288
TEST_F(ModulesTest,PrettyPrintHardshrink)5289 TEST_F(ModulesTest, PrettyPrintHardshrink) {
5290 ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)");
5291 ASSERT_EQ(
5292 c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))),
5293 "torch::nn::Hardshrink(42.42)");
5294 }
5295
TEST_F(ModulesTest,PrettyPrintHardtanh)5296 TEST_F(ModulesTest, PrettyPrintHardtanh) {
5297 ASSERT_EQ(c10::str(Hardtanh()), "torch::nn::Hardtanh(min_val=-1, max_val=1)");
5298 ASSERT_EQ(
5299 c10::str(Hardtanh(
5300 HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))),
5301 "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)");
5302 }
5303
TEST_F(ModulesTest,PrettyPrintLeakyReLU)5304 TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
5305 ASSERT_EQ(c10::str(LeakyReLU()), "torch::nn::LeakyReLU(negative_slope=0.01)");
5306 ASSERT_EQ(
5307 c10::str(
5308 LeakyReLU(LeakyReLUOptions().negative_slope(0.42).inplace(true))),
5309 "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
5310 }
5311
TEST_F(ModulesTest,PrettyPrintLogSigmoid)5312 TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
5313 ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()");
5314 }
5315
TEST_F(ModulesTest,PrettyPrintSoftmax)5316 TEST_F(ModulesTest, PrettyPrintSoftmax) {
5317 ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)");
5318 }
5319
TEST_F(ModulesTest,PrettyPrintSoftmin)5320 TEST_F(ModulesTest, PrettyPrintSoftmin) {
5321 ASSERT_EQ(c10::str(Softmin(SoftminOptions(1))), "torch::nn::Softmin(dim=1)");
5322 }
5323
TEST_F(ModulesTest,PrettyPrintLogSoftmax)5324 TEST_F(ModulesTest, PrettyPrintLogSoftmax) {
5325 ASSERT_EQ(
5326 c10::str(LogSoftmax(LogSoftmaxOptions(1))),
5327 "torch::nn::LogSoftmax(dim=1)");
5328 }
5329
TEST_F(ModulesTest,PrettyPrintSoftmax2d)5330 TEST_F(ModulesTest, PrettyPrintSoftmax2d) {
5331 ASSERT_EQ(c10::str(Softmax2d()), "torch::nn::Softmax2d()");
5332 }
5333
TEST_F(ModulesTest,PrettyPrintPReLU)5334 TEST_F(ModulesTest, PrettyPrintPReLU) {
5335 ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)");
5336 ASSERT_EQ(
5337 c10::str(PReLU(PReLUOptions().num_parameters(42))),
5338 "torch::nn::PReLU(num_parameters=42)");
5339 }
5340
TEST_F(ModulesTest,PrettyPrintReLU)5341 TEST_F(ModulesTest, PrettyPrintReLU) {
5342 ASSERT_EQ(c10::str(ReLU()), "torch::nn::ReLU()");
5343 ASSERT_EQ(
5344 c10::str(ReLU(ReLUOptions().inplace(true))),
5345 "torch::nn::ReLU(inplace=true)");
5346 ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), "torch::nn::ReLU(inplace=true)");
5347 }
5348
TEST_F(ModulesTest,PrettyPrintReLU6)5349 TEST_F(ModulesTest, PrettyPrintReLU6) {
5350 ASSERT_EQ(c10::str(ReLU6()), "torch::nn::ReLU6()");
5351 ASSERT_EQ(
5352 c10::str(ReLU6(ReLU6Options().inplace(true))),
5353 "torch::nn::ReLU6(inplace=true)");
5354 ASSERT_EQ(
5355 c10::str(ReLU6(/*inplace=*/true)), "torch::nn::ReLU6(inplace=true)");
5356 }
5357
TEST_F(ModulesTest,PrettyPrintRReLU)5358 TEST_F(ModulesTest, PrettyPrintRReLU) {
5359 ASSERT_EQ(c10::str(RReLU()), "torch::nn::RReLU(lower=0.125, upper=0.333333)");
5360 ASSERT_EQ(
5361 c10::str(RReLU(RReLUOptions().lower(0.24).upper(0.42).inplace(true))),
5362 "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)");
5363 }
5364
TEST_F(ModulesTest,PrettyPrintCELU)5365 TEST_F(ModulesTest, PrettyPrintCELU) {
5366 ASSERT_EQ(c10::str(CELU()), "torch::nn::CELU(alpha=1)");
5367 ASSERT_EQ(
5368 c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))),
5369 "torch::nn::CELU(alpha=42.42, inplace=true)");
5370 }
5371
TEST_F(ModulesTest,PrettyPrintSigmoid)5372 TEST_F(ModulesTest, PrettyPrintSigmoid) {
5373 ASSERT_EQ(c10::str(Sigmoid()), "torch::nn::Sigmoid()");
5374 }
5375
TEST_F(ModulesTest,PrettyPrintPixelShuffle)5376 TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
5377 ASSERT_EQ(
5378 c10::str(PixelShuffle(PixelShuffleOptions(5))),
5379 "torch::nn::PixelShuffle(upscale_factor=5)");
5380 }
5381
TEST_F(ModulesTest,PrettyPrintPixelUnshuffle)5382 TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
5383 ASSERT_EQ(
5384 c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
5385 "torch::nn::PixelUnshuffle(downscale_factor=5)");
5386 }
5387
TEST_F(ModulesTest,PrettyPrintSoftplus)5388 TEST_F(ModulesTest, PrettyPrintSoftplus) {
5389 ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)");
5390 ASSERT_EQ(
5391 c10::str(Softplus(SoftplusOptions().beta(0.24).threshold(42.42))),
5392 "torch::nn::Softplus(beta=0.24, threshold=42.42)");
5393 }
5394
TEST_F(ModulesTest,PrettyPrintSoftshrink)5395 TEST_F(ModulesTest, PrettyPrintSoftshrink) {
5396 ASSERT_EQ(c10::str(Softshrink()), "torch::nn::Softshrink(0.5)");
5397 ASSERT_EQ(
5398 c10::str(Softshrink(SoftshrinkOptions(42.42))),
5399 "torch::nn::Softshrink(42.42)");
5400 }
5401
TEST_F(ModulesTest,PrettyPrintSoftsign)5402 TEST_F(ModulesTest, PrettyPrintSoftsign) {
5403 ASSERT_EQ(c10::str(Softsign()), "torch::nn::Softsign()");
5404 }
5405
TEST_F(ModulesTest,PrettyPrintTanh)5406 TEST_F(ModulesTest, PrettyPrintTanh) {
5407 ASSERT_EQ(c10::str(Tanh()), "torch::nn::Tanh()");
5408 }
5409
TEST_F(ModulesTest,PrettyPrintTanhshrink)5410 TEST_F(ModulesTest, PrettyPrintTanhshrink) {
5411 ASSERT_EQ(c10::str(Tanhshrink()), "torch::nn::Tanhshrink()");
5412 }
5413
TEST_F(ModulesTest,PrettyPrintThreshold)5414 TEST_F(ModulesTest, PrettyPrintThreshold) {
5415 ASSERT_EQ(
5416 c10::str(Threshold(24.24, 42.42)),
5417 "torch::nn::Threshold(threshold=24.24, value=42.42)");
5418 ASSERT_EQ(
5419 c10::str(Threshold(ThresholdOptions(42.42, 24.24).inplace(true))),
5420 "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)");
5421 }
5422
TEST_F(ModulesTest,PrettyPrintCTCLoss)5423 TEST_F(ModulesTest, PrettyPrintCTCLoss) {
5424 ASSERT_EQ(c10::str(CTCLoss()), "torch::nn::CTCLoss()");
5425 ASSERT_EQ(
5426 c10::str(
5427 CTCLoss(CTCLossOptions().blank(42).zero_infinity(false).reduction(
5428 torch::kSum))),
5429 "torch::nn::CTCLoss()");
5430 }
5431
TEST_F(ModulesTest,PrettyPrintPoissonNLLLoss)5432 TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
5433 ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()");
5434 ASSERT_EQ(
5435 c10::str(PoissonNLLLoss(PoissonNLLLossOptions()
5436 .log_input(false)
5437 .full(true)
5438 .eps(0.42)
5439 .reduction(torch::kSum))),
5440 "torch::nn::PoissonNLLLoss()");
5441 }
5442
TEST_F(ModulesTest,PrettyPrintMarginRankingLoss)5443 TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) {
5444 ASSERT_EQ(c10::str(MarginRankingLoss()), "torch::nn::MarginRankingLoss()");
5445 ASSERT_EQ(
5446 c10::str(MarginRankingLoss(
5447 MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))),
5448 "torch::nn::MarginRankingLoss()");
5449 }
5450
TEST_F(ModulesTest,PrettyPrintCrossMapLRN2d)5451 TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) {
5452 ASSERT_EQ(
5453 c10::str(CrossMapLRN2d(4)),
5454 "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)");
5455 ASSERT_EQ(
5456 c10::str(
5457 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))),
5458 "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)");
5459 }
5460
TEST_F(ModulesTest,PrettyPrintAlphaDropout)5461 TEST_F(ModulesTest, PrettyPrintAlphaDropout) {
5462 ASSERT_EQ(
5463 c10::str(AlphaDropout()),
5464 "torch::nn::AlphaDropout(p=0.5, inplace=false)");
5465 ASSERT_EQ(
5466 c10::str(AlphaDropout(AlphaDropoutOptions(0.2))),
5467 "torch::nn::AlphaDropout(p=0.2, inplace=false)");
5468 ASSERT_EQ(
5469 c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))),
5470 "torch::nn::AlphaDropout(p=0.2, inplace=true)");
5471 }
5472
TEST_F(ModulesTest,PrettyPrintFeatureAlphaDropout)5473 TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) {
5474 ASSERT_EQ(
5475 c10::str(FeatureAlphaDropout()),
5476 "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)");
5477 ASSERT_EQ(
5478 c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))),
5479 "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)");
5480 ASSERT_EQ(
5481 c10::str(
5482 FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))),
5483 "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)");
5484 }
5485
TEST_F(ModulesTest,PrettyPrintBCEWithLogitsLoss)5486 TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) {
5487 ASSERT_EQ(c10::str(BCEWithLogitsLoss()), "torch::nn::BCEWithLogitsLoss()");
5488 ASSERT_EQ(
5489 c10::str(BCEWithLogitsLoss(BCEWithLogitsLossOptions()
5490 .weight(torch::ones({3, 3}))
5491 .pos_weight(torch::ones({3, 3}))
5492 .reduction(torch::kSum))),
5493 "torch::nn::BCEWithLogitsLoss()");
5494 }
5495
TEST_F(ModulesTest,PrettyPrintMultiheadAttention)5496 TEST_F(ModulesTest, PrettyPrintMultiheadAttention) {
5497 ASSERT_EQ(
5498 c10::str(MultiheadAttention(20, 10)),
5499 "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)");
5500 ASSERT_EQ(
5501 c10::str(
5502 MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))),
5503 "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)");
5504 }
5505
TEST_F(ModulesTest,PrettyPrintRNNCell)5506 TEST_F(ModulesTest, PrettyPrintRNNCell) {
5507 ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)");
5508 ASSERT_EQ(
5509 c10::str(RNNCell(
5510 RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))),
5511 "torch::nn::RNNCell(20, 10, bias=false)");
5512 ASSERT_EQ(
5513 c10::str(RNNCell(
5514 RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))),
5515 "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)");
5516 }
5517
TEST_F(ModulesTest,PrettyPrintLSTMCell)5518 TEST_F(ModulesTest, PrettyPrintLSTMCell) {
5519 ASSERT_EQ(c10::str(LSTMCell(20, 10)), "torch::nn::LSTMCell(20, 10)");
5520 ASSERT_EQ(
5521 c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))),
5522 "torch::nn::LSTMCell(20, 10, bias=false)");
5523 }
5524
TEST_F(ModulesTest,PrettyPrintGRUCell)5525 TEST_F(ModulesTest, PrettyPrintGRUCell) {
5526 ASSERT_EQ(c10::str(GRUCell(20, 10)), "torch::nn::GRUCell(20, 10)");
5527 ASSERT_EQ(
5528 c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))),
5529 "torch::nn::GRUCell(20, 10, bias=false)");
5530 }
5531
TEST_F(ModulesTest,PrettyPrintAdaptiveLogSoftmaxWithLoss)5532 TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
5533 {
5534 AdaptiveLogSoftmaxWithLoss asfm(
5535 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
5536 ASSERT_EQ(
5537 c10::str(asfm),
5538 "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5539 " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
5540 " (tail): torch::nn::ModuleList(\n"
5541 " (0): torch::nn::Sequential(\n"
5542 " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5543 " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
5544 " )\n"
5545 " )\n"
5546 ")");
5547 }
5548 {
5549 AdaptiveLogSoftmaxWithLoss asfm(
5550 AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
5551 .div_value(2.)
5552 .head_bias(true));
5553 ASSERT_EQ(
5554 c10::str(asfm),
5555 "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5556 " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
5557 " (tail): torch::nn::ModuleList(\n"
5558 " (0): torch::nn::Sequential(\n"
5559 " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5560 " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
5561 " )\n"
5562 " (1): torch::nn::Sequential(\n"
5563 " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
5564 " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
5565 " )\n"
5566 " )\n"
5567 ")");
5568 }
5569 }
5570