// xref: /aosp_15_r20/external/pytorch/test/cpp/api/functional.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

namespace F = torch::nn::functional;

using namespace torch::nn;
12 struct FunctionalTest : torch::test::SeedingFixture {};
13 
// Verifies F::conv1d against hand-computed expected values, both with
// explicit options and with the default-options overload.
TEST_F(FunctionalTest, Conv1d) {
  auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({2, 3, 5});
  auto weight =
      torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3});
  auto y = F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{312., 348., 384.}, {798., 915., 1032.}},

       {{852., 888., 924.}, {2553., 2670., 2787.}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options must match stride(1).
  auto y_no_options = F::conv1d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
31 
// Verifies F::conv2d with a square (even) kernel against precomputed values.
TEST_F(FunctionalTest, Conv2dEven) {
  auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5});
  auto weight =
      torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},

        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options must match stride(1).
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
53 
// Verifies F::conv2d with a non-square (3x2) kernel against precomputed values.
TEST_F(FunctionalTest, Conv2dUneven) {
  auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 4});
  auto weight =
      torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 2});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},

        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options must match stride(1).
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
73 
// Verifies F::conv3d against precomputed values, with and without options.
TEST_F(FunctionalTest, Conv3d) {
  auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5, 5});
  auto weight =
      torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3, 3});
  auto y = F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{{700704., 703944., 707184.},
          {716904., 720144., 723384.},
          {733104., 736344., 739584.}},

         {{781704., 784944., 788184.},
          {797904., 801144., 804384.},
          {814104., 817344., 820584.}},

         {{862704., 865944., 869184.},
          {878904., 882144., 885384.},
          {895104., 898344., 901584.}}},

        {{{1724220., 1734021., 1743822.},
          {1773225., 1783026., 1792827.},
          {1822230., 1832031., 1841832.}},

         {{1969245., 1979046., 1988847.},
          {2018250., 2028051., 2037852.},
          {2067255., 2077056., 2086857.}},

         {{2214270., 2224071., 2233872.},
          {2263275., 2273076., 2282877.},
          {2312280., 2322081., 2331882.}}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options must match stride(1).
  auto y_no_options = F::conv3d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
111 
// Max-pooling an all-ones input yields all-ones output of the expected shape.
TEST_F(FunctionalTest, MaxPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}
120 
// Max-pooling an all-ones 3D (unbatched) input yields ones of the expected shape.
TEST_F(FunctionalTest, MaxPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}
129 
// Gradient through max_pool2d has the same shape as the input.
TEST_F(FunctionalTest, MaxPool2dBackward) {
  auto input = torch::rand(
      {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2));
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
138 
// Max-pooling an all-ones 4D input yields ones of the expected shape.
TEST_F(FunctionalTest, MaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
147 
// Average-pooling an all-ones input yields all-ones output of the expected shape.
TEST_F(FunctionalTest, AvgPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}
156 
// Average-pooling an all-ones 3D (unbatched) input yields ones of the expected shape.
TEST_F(FunctionalTest, AvgPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}
165 
// Average-pooling an all-ones 4D input yields ones of the expected shape.
TEST_F(FunctionalTest, AvgPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
174 
// Checks fractional_max_pool2d values/shapes for unbatched (3D) and batched
// (4D) inputs, and that the _with_indices variant returns matching values
// plus the expected flat indices (seeded RNG makes indices deterministic).
TEST_F(FunctionalTest, FractionalMaxPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::fractional_max_pool2d(
      x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));

  auto y_with_indices = F::fractional_max_pool2d_with_indices(
      x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y_with_indices),
      torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
  ASSERT_EQ(
      std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2}));

  // Batched (4D) input.
  auto x1 = torch::ones({2, 2, 5, 5});
  auto y1 = F::fractional_max_pool2d(
      x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));

  ASSERT_EQ(y1.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y1, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y1.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  auto y1_with_indices = F::fractional_max_pool2d_with_indices(
      x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y1, std::get<0>(y1_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y1_with_indices),
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}},
           {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}})));
  ASSERT_EQ(
      std::get<1>(y1_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
212 
// Checks fractional_max_pool3d output values/shapes and the _with_indices
// variant (seeded RNG makes indices deterministic).
TEST_F(FunctionalTest, FractionalMaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::fractional_max_pool3d(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  auto y_with_indices = F::fractional_max_pool3d_with_indices(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y_with_indices),
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
  ASSERT_EQ(
      std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
233 
// lp_pool1d on all-ones input equals (1^p * kernel_size)^(1/p).
TEST_F(FunctionalTest, LPPool1d) {
  int norm_type = 2;
  int stride = 2;
  int kernel_size = 3;

  auto x = torch::ones({1, 1, 5});
  auto y = F::lp_pool1d(
      x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
       kernel_size)
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
}
251 
// lp_pool2d on all-ones input equals (1^p * kH*kW)^(1/p), with a
// rectangular kernel.
TEST_F(FunctionalTest, LPPool2d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  auto x = torch::ones({1, 1, 2, 5});
  auto y = F::lp_pool2d(
      x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
}
269 
// lp_pool3d on all-ones input equals (1^p * kD*kH*kW)^(1/p).
TEST_F(FunctionalTest, LPPool3d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({1, 2, 3});

  auto x = torch::ones({1, 1, 1, 2, 5});
  auto y = F::lp_pool3d(
      x, F::LPPool3dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1] * kernel_size[2]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 5);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
}
287 
// cosine_similarity along dim=1 against precomputed reference values.
TEST_F(FunctionalTest, CosineSimilarity) {
  auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  auto output = F::cosine_similarity(
      input1, input2, F::CosineSimilarityFuncOptions().dim(1));
  auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));
}
296 
// smooth_l1_loss with default options (mean reduction, beta=1) and backward.
TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::smooth_l1_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
308 
// smooth_l1_loss with an explicit beta argument (positional overload).
TEST_F(FunctionalTest, SmoothL1LossBeta) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment)
      F::smooth_l1_loss(
          input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
  auto expected = torch::tensor(1.67, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
323 
// Same as SmoothL1LossBeta but with beta set via the options struct;
// both paths must agree.
TEST_F(FunctionalTest, SmoothL1LossBetaOptions) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      F::smooth_l1_loss(
          input,
          target,
          F::SmoothL1LossFuncOptions().reduction(torch::kMean).beta(0.5));
  auto expected = torch::tensor(1.67, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
340 
// smooth_l1_loss with reduction=None returns per-element losses.
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(bugprone-argument-comment)
      F::smooth_l1_loss(input, target, /*reduction=*/torch::kNone);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
354 
// huber_loss with default options (mean reduction, delta=1) and backward.
TEST_F(FunctionalTest, HuberLossDefaultOptions) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::huber_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
366 
// huber_loss with delta=0.5 equals the smooth-L1 value scaled by delta.
TEST_F(FunctionalTest, HuberLossDelta) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor(1.67 * 0.5, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
379 
// huber_loss with reduction=None returns per-element losses.
TEST_F(FunctionalTest, HuberLossNoReduction) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kNone);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
392 
// soft_margin_loss with default (mean) reduction and backward shape check.
TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) {
  auto input = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
  auto output = F::soft_margin_loss(input, target);
  auto expected = torch::tensor({1.3767317}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
405 
// multilabel_soft_margin_loss with default options and backward shape check.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
  auto input = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  auto output = F::multilabel_soft_margin_loss(input, target);
  auto expected = torch::tensor({0.7608436}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
420 
// soft_margin_loss with reduction=None returns per-element losses.
TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
  auto input = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
  auto output = F::soft_margin_loss(input, target, torch::kNone);
  auto expected = torch::tensor(
      {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
434 
// multilabel_soft_margin_loss with per-class weights and no reduction.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
  auto input = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
  auto options = F::MultilabelSoftMarginLossFuncOptions()
                     .reduction(torch::kNone)
                     .weight(weight);
  auto output = F::multilabel_soft_margin_loss(input, target, options);
  auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
453 
// pairwise_distance with p=1 (Manhattan distance) against precomputed values.
TEST_F(FunctionalTest, PairwiseDistance) {
  auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  auto output = F::pairwise_distance(
      input1, input2, F::PairwiseDistanceFuncOptions().p(1));
  auto expected = torch::tensor({6, 6}, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected));
}
462 
// pdist with the default p=2 and with a fractional norm p=1.5.
TEST_F(FunctionalTest, PDist) {
  {
    auto input = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}});
    auto output = F::pdist(input);
    auto expected = torch::tensor({11.7898});
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    auto input = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}});
    auto output = F::pdist(input, 1.5);
    auto expected = torch::tensor({4.0, 4.8945, 2.0});
    ASSERT_TRUE(output.allclose(expected));
  }
}
477 
// adaptive_max_pool1d to output size 3 on all-ones input.
TEST_F(FunctionalTest, AdaptiveMaxPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
}
486 
// adaptive_max_pool2d to output size 3x3 on all-ones (unbatched) input.
TEST_F(FunctionalTest, AdaptiveMaxPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
}
495 
// adaptive_max_pool3d to output size 3x3x3 on all-ones input.
TEST_F(FunctionalTest, AdaptiveMaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3}));
}
504 
// adaptive_avg_pool1d to output size 3 on all-ones input.
TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
}
513 
// adaptive_avg_pool2d to output size 3x3 on all-ones (unbatched) input.
TEST_F(FunctionalTest, AdaptiveAvgPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
}
522 
// adaptive_avg_pool3d to output size 3x3x3 on all-ones input.
TEST_F(FunctionalTest, AdaptiveAvgPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3}));
}
531 
// l1_loss with default (mean) reduction: scalar output, gradient shape matches input.
TEST_F(FunctionalTest, L1Loss) {
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::l1_loss(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
542 
// mse_loss with default (mean) reduction: scalar output, gradient shape matches input.
TEST_F(FunctionalTest, MSELoss) {
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::mse_loss(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
553 
// binary_cross_entropy on sigmoid probabilities: scalar output, backward works.
TEST_F(FunctionalTest, BCELoss) {
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::binary_cross_entropy(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
564 
// kl_div with default options: scalar output, backward works.
// NOTE: dropped an unused local `KLDivLoss loss;` from the original —
// the functional path under test never touched it.
TEST_F(FunctionalTest, KLDivLoss) {
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::kl_div(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
576 
// hinge_embedding_loss with margin=2 against a precomputed value.
TEST_F(FunctionalTest, HingeEmbeddingLoss) {
  auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat);
  auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
  auto output = F::hinge_embedding_loss(
      input, target, F::HingeEmbeddingLossFuncOptions().margin(2));
  auto expected = torch::tensor({10}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));
}
586 
TEST_F(FunctionalTest,GridSample)587 TEST_F(FunctionalTest, GridSample) {
588   auto input =
589       torch::arange(9, torch::kFloat).view(std::vector<int64_t>({1, 1, 3, 3}));
590   auto grid = torch::tensor(
591       {{{{-2., -1.}, {-1., -1.}, {0., -1.}},
592         {{-1., 0.}, {0., 0.}, {1., 0.}},
593         {{0., 1.}, {1., 1.}, {2., 1.}}}},
594       torch::kFloat);
595 
596   // bilinear, zeros, true
597   auto options = F::GridSampleFuncOptions()
598                      .mode(torch::kBilinear)
599                      .padding_mode(torch::kZeros)
600                      .align_corners(true);
601   auto output = F::grid_sample(input, grid, options);
602   auto expected = torch::tensor(
603       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
604 
605   ASSERT_TRUE(output.allclose(expected));
606 
607   // bilinear, zeros, false
608   options = F::GridSampleFuncOptions()
609                 .mode(torch::kBilinear)
610                 .padding_mode(torch::kZeros)
611                 .align_corners(false);
612   output = F::grid_sample(input, grid, options);
613   expected = torch::tensor(
614       {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat);
615 
616   ASSERT_TRUE(output.allclose(expected));
617 
618   // default options (bilinear, zeros, false) same result as above
619   output = F::grid_sample(input, grid);
620 
621   ASSERT_TRUE(output.allclose(expected));
622 
623   // nearest, zeros, true
624   options = F::GridSampleFuncOptions()
625                 .mode(torch::kNearest)
626                 .padding_mode(torch::kZeros)
627                 .align_corners(true);
628   output = F::grid_sample(input, grid, options);
629   expected = torch::tensor(
630       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
631 
632   ASSERT_TRUE(output.allclose(expected));
633 
634   // bilinear, border, true
635   options = F::GridSampleFuncOptions()
636                 .mode(torch::kBilinear)
637                 .padding_mode(torch::kBorder)
638                 .align_corners(true);
639   output = F::grid_sample(input, grid, options);
640   expected = torch::tensor(
641       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat);
642 
643   ASSERT_TRUE(output.allclose(expected));
644 
645   // bilinear, reflection, true
646   options = F::GridSampleFuncOptions()
647                 .mode(torch::kBilinear)
648                 .padding_mode(torch::kReflection)
649                 .align_corners(true);
650   output = F::grid_sample(input, grid, options);
651   expected = torch::tensor(
652       {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat);
653 
654   ASSERT_TRUE(output.allclose(expected));
655 }
656 
TEST_F(FunctionalTest,AffineGrid)657 TEST_F(FunctionalTest, AffineGrid) {
658   {
659     // 2D affine.
660     auto theta = torch::arange(1., 13).view(std::vector<int64_t>({2, 2, 3}));
661     auto size = std::vector<int64_t>({2, 3, 2, 2});
662     auto align_corners = true;
663     auto output = F::affine_grid(theta, size, !align_corners);
664     auto expected = torch::tensor(
665         {{{{1.50, 1.50}, {2.50, 5.50}}, {{3.50, 6.50}, {4.50, 10.50}}},
666          {{{1.50, 1.50}, {8.50, 11.50}}, {{9.50, 12.50}, {16.50, 22.50}}}});
667     auto output_aligned = F::affine_grid(theta, size, align_corners);
668     auto expected_aligned = torch::tensor(
669         {{{{0.0, -3.0}, {2.0, 5.0}}, {{4.0, 7.0}, {6.0, 15.0}}},
670          {{{-6.0, -9.0}, {8.0, 11.0}}, {{10.0, 13.0}, {24.0, 33.0}}}});
671 
672     ASSERT_TRUE(output.allclose(expected));
673     ASSERT_TRUE(output_aligned.allclose(expected_aligned));
674   }
675   {
676     // 3D affine.
677     auto theta = torch::arange(1., 13).view(std::vector<int64_t>({1, 3, 4}));
678     auto size = std::vector<int64_t>({1, 1, 3, 2, 2});
679     auto align_corners = true;
680     auto output = F::affine_grid(theta, size, !align_corners);
681     auto expected = torch::tensor(
682         {{{{{0.5000, -2.1667, -4.8333}, {1.5000, 2.8333, 4.1667}},
683            {{2.5000, 3.8333, 5.1667}, {3.5000, 8.8333, 14.1667}}},
684           {{{2.5000, 2.5000, 2.5000}, {3.5000, 7.5000, 11.5000}},
685            {{4.5000, 8.5000, 12.5000}, {5.5000, 13.5000, 21.5000}}},
686           {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}},
687            {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}});
688     auto output_aligned = F::affine_grid(theta, size, align_corners);
689     auto expected_aligned = torch::tensor(
690         {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}},
691            {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}},
692           {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}},
693            {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}},
694           {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}},
695            {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}});
696 
697     ASSERT_TRUE(output.allclose(expected, 1e-2));
698     ASSERT_TRUE(output_aligned.allclose(expected_aligned));
699   }
700   {
701     auto theta = torch::empty({1, 2, 3}, torch::kDouble);
702     auto size = std::vector<int64_t>({1, 1, 2, 2});
703     ASSERT_THROWS_WITH(
704         F::affine_grid(torch::empty({2, 2, 3}), {-1, 1, 2, 2}),
705         "Expected non-zero, positive output size. Got [-1, 1, 2, 2]");
706     ASSERT_THROWS_WITH(
707         F::affine_grid(torch::empty({2, 2, 3}, torch::kInt), size),
708         "Expected theta to have floating point type, but got int");
709     ASSERT_THROWS_WITH(
710         F::affine_grid(theta[0], size),
711         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
712         "[1, 1, 2, 2]. Got [2, 3].");
713     ASSERT_THROWS_WITH(
714         F::affine_grid(theta.unsqueeze(0), size),
715         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
716         "[1, 1, 2, 2]. Got [1, 1, 2, 3].");
717     ASSERT_THROWS_WITH(
718         F::affine_grid(theta.repeat({1, 2, 1}), size),
719         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
720         "[1, 1, 2, 2]. Got [1, 4, 3].");
721     ASSERT_THROWS_WITH(
722         F::affine_grid(theta.repeat({1, 1, 2}), size),
723         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
724         "[1, 1, 2, 2]. Got [1, 2, 6].");
725   }
726   {
727     auto theta = torch::empty({1, 3, 4}, torch::kDouble);
728     auto size = std::vector<int64_t>({1, 1, 2, 2, 3});
729     ASSERT_THROWS_WITH(
730         F::affine_grid(theta[0], size),
731         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
732         "[1, 1, 2, 2, 3]. Got [3, 4].");
733     ASSERT_THROWS_WITH(
734         F::affine_grid(theta.unsqueeze(0), size),
735         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
736         "[1, 1, 2, 2, 3]. Got [1, 1, 3, 4].");
737     ASSERT_THROWS_WITH(
738         F::affine_grid(theta.repeat({1, 2, 1}), size),
739         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
740         "[1, 1, 2, 2, 3]. Got [1, 6, 4].");
741     ASSERT_THROWS_WITH(
742         F::affine_grid(theta.repeat({1, 1, 2}), size),
743         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
744         "[1, 1, 2, 2, 3]. Got [1, 3, 8].");
745     ASSERT_THROWS_WITH(
746         F::affine_grid(theta, {1, 1, 1, 2, 2, 3}),
747         "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
748         "transforms, respectively. Got size [1, 1, 1, 2, 2, 3]");
749     ASSERT_THROWS_WITH(
750         F::affine_grid(theta, {1, 1}),
751         "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
752         "transforms, respectively. Got size [1, 1]");
753   }
754 }
755 
TEST_F(FunctionalTest, MultiMarginLoss) {
  // Per-class weights applied to the margin loss of each target class.
  const auto class_weights = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
  auto scores = torch::tensor(
      {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({2, 1, 0}, torch::kLong);

  const auto loss = F::multi_margin_loss(
      scores,
      labels,
      F::MultiMarginLossFuncOptions().margin(2).weight(class_weights));

  // Reference value computed with the Python F.multi_margin_loss.
  ASSERT_TRUE(loss.allclose(torch::tensor({0.305556}, torch::kFloat), 1e-04));
}
768 
TEST_F(FunctionalTest, CosineEmbeddingLoss) {
  // Two pairs: the first labeled similar (1), the second dissimilar (-1)
  // with a 0.5 margin. Reference value from the Python API.
  const auto first = torch::tensor({{2, 3, 4}, {6, 2, 4}});
  const auto second = torch::tensor({{2, 3, 5}, {9, 12, 0}});
  const auto labels = torch::tensor({1, -1});

  const auto loss = F::cosine_embedding_loss(
      first, second, labels, F::CosineEmbeddingLossFuncOptions().margin(0.5));

  ASSERT_TRUE(loss.allclose(torch::tensor({0.1004}, torch::kFloat), 1e-4));
}
779 
TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) {
  // Target labels are 3 and 0; -1 terminates the label list per sample.
  auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);

  const auto loss = F::multilabel_margin_loss(scores, labels);
  // Backprop through the scalar sum to exercise the autograd path.
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(torch::tensor({0.8500}, torch::kFloat)));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
792 
TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) {
  // Same fixture as the default-options test, but with reduction disabled;
  // with a single sample the unreduced loss equals the mean-reduced one.
  auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);

  const auto loss = F::multilabel_margin_loss(scores, labels, torch::kNone);
  // Backprop through the scalar sum to exercise the autograd path.
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(torch::tensor({0.8500}, torch::kFloat)));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
805 
TEST_F(FunctionalTest, TripletMarginLoss) {
  const auto anchor = torch::tensor({{3., 3.}}, torch::kFloat);
  const auto positive = torch::tensor({{2., 2.}}, torch::kFloat);
  const auto negative = torch::tensor({{0., 0.}}, torch::kFloat);

  // d(a,p)=sqrt(2), d(a,n)=sqrt(18): margin + d(a,p) - d(a,n) < 0, so the
  // hinge clamps the loss to zero.
  const auto loss = F::triplet_margin_loss(
      anchor,
      positive,
      negative,
      F::TripletMarginLossFuncOptions().margin(1.0));

  ASSERT_TRUE(loss.allclose(torch::tensor({0.}, torch::kFloat), 1e-04));
}
819 
TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
  // Check that if we use torch::pairwise_distance with the default
  // TripletMarginLoss options as our distance function, the outputs
  // are equal (i.e., equal under defaults), for every combination of
  // reduction, margin, and swap.
  //
  // NOTE: the original constructed TripletMarginLoss / TripletMarginWith-
  // DistanceLoss module instances from the options but never used them;
  // those dead locals have been removed.

  std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
      torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margins = {0.5, 1.0, 1.5};
  std::vector<bool> swaps = {true, false};

  for (const auto& reduction : reductions) {
    for (const auto& margin : margins) {
      for (const auto& swap : swaps) {
        auto anchor = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto positive = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto negative = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

        auto basicOptions = F::TripletMarginLossFuncOptions()
                                .reduction(reduction)
                                .margin(margin)
                                .swap(swap);
        auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions()
                                   .reduction(reduction)
                                   .margin(margin)
                                   .swap(swap);

        auto basicOutput =
            F::triplet_margin_loss(anchor, positive, negative, basicOptions);
        auto distanceOutput = F::triplet_margin_with_distance_loss(
            anchor, positive, negative, distanceOptions);

        ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));

        // Sum first so backward() also works for torch::kNone reduction;
        // gradients must match the input shapes.
        auto sum = distanceOutput.sum();
        sum.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}
868 
TEST_F(FunctionalTest, NLLLoss) {
  // Inputs are log-probabilities; the loss picks out -log p[target] per row.
  const auto log_probs = torch::tensor(
      {{-0.1315, -3.1315, -2.5315},
       {-3.7038, -0.1038, -2.6038},
       {-2.3422, -1.3422, -0.4422}},
      torch::kFloat);
  const auto labels = torch::tensor({1, 0, 2}, torch::kLong);
  const auto expected = torch::tensor(2.4258, torch::kFloat);

  const auto loss = F::nll_loss(
      log_probs,
      labels,
      F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean));
  ASSERT_TRUE(loss.allclose(expected, 1e-04));

  // The explicit options above are the defaults, so the no-options call
  // must agree.
  ASSERT_TRUE(F::nll_loss(log_probs, labels).allclose(expected, 1e-04));
}
884 
TEST_F(FunctionalTest, CrossEntropy) {
  // Both rows have equal logits, so the loss is ln(2) ~= 0.6931 regardless
  // of which class is the target.
  auto input = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat);
  auto target = torch::tensor({0, 1}, torch::kLong);
  auto output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean));
  auto expected = torch::tensor(0.6931, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected, 1e-04));
  // The explicit options above are the defaults, so the no-options call
  // must agree.
  ASSERT_TRUE(F::cross_entropy(input, target).allclose(expected, 1e-04));

  // label smoothing with class indices
  // (reference value computed with the Python API)
  input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);
  output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(
          torch::kMean));
  expected = torch::tensor(0.3326, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));

  // label smoothing with target probabilities
  // (targets given as a float distribution instead of class indices)
  target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
  output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(
          torch::kMean));
  expected = torch::tensor(0.5701, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));
}
917 
TEST_F(FunctionalTest, MaxUnpool1d) {
  // Batched 3D input: each value is scattered to the position given by
  // `indices`; all other output positions are zero.
  auto x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  auto y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));

  // Unbatched 2D input is accepted as well and keeps its rank.
  x = torch::tensor(
      {{2, 4, 5}}, torch::dtype(torch::kFloat).requires_grad(true));
  indices = torch::tensor({{1, 3, 4}}, torch::kLong);
  y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{0, 2, 0, 4, 5, 0, 0, 0, 0}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 9}));

  // An explicit output_size pins the size of the unpooled dimension
  // (here it matches the inferred size, so the result is unchanged).
  x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  y = F::max_unpool1d(
      x,
      indices,
      F::MaxUnpool1dFuncOptions(3).output_size(
          std::vector<int64_t>({1, 1, 9})));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));

  // Non-default stride/padding shrink the inferred output length to 5.
  x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  y = F::max_unpool1d(
      x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(
      torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
}
964 
TEST_F(FunctionalTest, MaxUnpool2d) {
  // Batched 4D input: values are scattered into a 5x5 output (inferred from
  // kernel_size=3, stride=2, padding=1) at flat positions from `indices`.
  auto indices = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
      torch::kLong);
  auto x = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto y = F::max_unpool2d(
      x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{0, 0, 0, 0, 0},
             {0, 6, 0, 8, 9},
             {0, 0, 0, 0, 0},
             {0, 16, 0, 18, 19},
             {0, 21, 0, 23, 24}}},
           {{{0, 0, 0, 0, 0},
             {0, 31, 0, 33, 34},
             {0, 0, 0, 0, 0},
             {0, 41, 0, 43, 44},
             {0, 46, 0, 48, 49}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));

  // Unbatched 3D input: same scatter, output keeps rank 3.
  indices = torch::tensor(
      {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
       {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
      torch::kLong);
  x = torch::tensor(
      {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
       {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  y = F::max_unpool2d(
      x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{0, 0, 0, 0, 0},
            {0, 6, 0, 8, 9},
            {0, 0, 0, 0, 0},
            {0, 16, 0, 18, 19},
            {0, 21, 0, 23, 24}},
           {{0, 0, 0, 0, 0},
            {0, 31, 0, 33, 34},
            {0, 0, 0, 0, 0},
            {0, 41, 0, 43, 44},
            {0, 46, 0, 48, 49}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 5, 5}));
}
1022 
TEST_F(FunctionalTest, MaxUnpool3d) {
  // Batched 5D input: the single value 26 lands at flat index 26, the last
  // cell of the 3x3x3 output volume.
  auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
  auto x = torch::tensor(
      {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3));

  ASSERT_EQ(y.dim(), 5);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
             {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
             {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));

  // Unbatched 4D input: same scatter, output keeps rank 4.
  indices = torch::tensor({{{{26}}}}, torch::kLong);
  x = torch::tensor(
      {{{{26}}}}, torch::dtype(torch::kFloat).requires_grad(true));
  y = F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3));

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
            {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
            {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}
1054 
TEST_F(FunctionalTest, ELU) {
  const auto size = 3;
  for (const auto inplace : {false, true}) {
    for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_({size, size, size});
      auto input_bf16 =
          torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16);
      input_bf16.resize_({size, size, size});

      // Reference: max(0, x) + min(0, alpha * (exp(x) - 1)).
      const auto expected = torch::max(torch::zeros_like(input), input) +
          torch::min(
              torch::zeros_like(input), alpha * (torch::exp(input) - 1.0));
      const auto options = F::ELUFuncOptions().alpha(alpha).inplace(inplace);
      auto output = F::elu(input, options);
      auto output_bf16 = F::elu(input_bf16, options);

      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(output, expected));
      // bfloat16 result agrees within its reduced precision.
      ASSERT_TRUE(
          torch::allclose(output_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      if (inplace) {
        // In-place mode must overwrite the inputs themselves.
        ASSERT_TRUE(torch::allclose(input, expected));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      }
    }
  }
  ASSERT_TRUE(F::elu(torch::tensor(1.)).defined());
}
1083 
TEST_F(FunctionalTest, SELU) {
  {
    const double scale = 1.0507009873554804934193349852946;
    const double alpha = 1.6732632423543772848170429916717;
    for (const auto inplace : {false, true}) {
      auto input = torch::randn({5, 5});
      auto input_bf16 = input.clone().to(torch::kBFloat16);
      // Reference: scale * (max(0, x) + min(0, alpha * (exp(x) - 1))).
      const auto expected = scale *
          (torch::max(torch::zeros_like(input), input) +
           torch::min(
               torch::zeros_like(input), alpha * (torch::exp(input) - 1)));
      auto output = F::selu(input, inplace);
      auto output_bf16 = F::selu(input_bf16, inplace);

      ASSERT_TRUE(output.allclose(expected));
      // bfloat16 result agrees within its reduced precision.
      ASSERT_TRUE(output_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
      if (inplace) {
        // In-place mode must overwrite the inputs themselves.
        ASSERT_TRUE(input.allclose(expected));
        ASSERT_TRUE(input_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
      }
    }
  }
  {
    // The single-argument overload defaults to inplace=false.
    auto input = torch::arange(0, 9, torch::kDouble).view({3, 3});
    ASSERT_TRUE(F::selu(input).allclose(F::selu(input, false)));
  }
  ASSERT_TRUE(F::selu(torch::tensor(1.)).defined());
}
1114 
TEST_F(FunctionalTest, GLU) {
  int64_t dim = 1;
  auto input = torch::randn({4, 2}, torch::requires_grad());
  auto output = F::glu(input, dim);

  // GLU halves the input along `dim` and gates the first half with the
  // sigmoid of the second: a * sigmoid(b).
  const auto half = input.sizes()[dim] / 2;
  const auto gate_in = input.narrow(dim, 0, half);
  const auto gate_ctrl = input.narrow(dim, half, half);
  const auto expected = gate_in * torch::sigmoid(gate_ctrl);

  ASSERT_TRUE(output.allclose(expected));
  // dim=1 is the default.
  ASSERT_TRUE(F::glu(input).allclose(expected));
}
1127 
TEST_F(FunctionalTest, GELU) {
  // Exact (erf-based) GELU: x * Phi(x) with Phi the standard normal CDF.
  const auto input = torch::linspace(-3.0, 3.0, 100);
  const auto expected =
      input * 0.5 * (1.0 + torch::erf(input / std::sqrt(2.0)));
  const auto output = F::gelu(input, F::GELUFuncOptions().approximate("none"));
  ASSERT_TRUE(torch::allclose(output, expected, 1.4e-06, 1e-05));
}
1134 
TEST_F(FunctionalTest, TanhGELU) {
  // tanh approximation:
  // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
  const auto input = torch::linspace(-3.0, 3.0, 100);
  const auto inner = std::sqrt(2 / M_PI) * (input + 0.044715 * input.pow(3.0));
  const auto expected = 0.5 * input * (1.0 + inner.tanh());
  const auto output = F::gelu(input, F::GELUFuncOptions().approximate("tanh"));
  ASSERT_TRUE(torch::allclose(output, expected, 1.4e-06, 1e-05));
}
1142 
TEST_F(FunctionalTest, Hardshrink) {
  const auto size = 3;
  for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size}).set_requires_grad(true);
    auto output =
        F::hardshrink(input, F::HardshrinkFuncOptions().lambda(lambda));
    torch::Tensor total = output.sum();

    // Backprop through the scalar sum; the sum itself is 0-dimensional.
    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));
    // Entries with |x| <= lambda are zeroed; everything else passes through.
    const auto expected = (input.abs() > lambda) * input;
    ASSERT_TRUE(torch::allclose(output, expected));
  }
}
1160 
TEST_F(FunctionalTest, OneHot) {
  { // Number of classes inferred from the data (max value + 1 == 3).
    const auto input = torch::arange(0, 5, torch::kLong);
    const auto encoded = F::one_hot(input % 3);
    const auto expected = torch::tensor(
        {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, expected));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({5, 3}));
  }

  { // Explicit class count larger than max value + 1 pads with zeros.
    const auto input = torch::arange(0, 5, torch::kLong);
    const auto encoded = F::one_hot(input % 3, 5);
    const auto expected = torch::tensor(
        {{1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0},
         {0, 0, 1, 0, 0},
         {1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, expected));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({5, 5}));
  }

  { // Multi-dimensional input: the one-hot axis is appended as a new
    // trailing dimension.
    const auto input = torch::arange(0, 6, torch::kLong);
    const auto encoded =
        F::one_hot(input.view(std::vector<int64_t>({3, 2})) % 3);
    const auto expected = torch::tensor(
        {{{1, 0, 0}, {0, 1, 0}},
         {{0, 0, 1}, {1, 0, 0}},
         {{0, 1, 0}, {0, 0, 1}}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(encoded, expected));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({3, 2, 3}));
  }
}
1203 
TEST_F(FunctionalTest, Hardtanh) {
  const auto size = 3;
  for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
    for (const auto max_val : {0.0, 0.42, 1.0, 4.2}) {
      for (const auto inplace : {false, true}) {
        auto input = torch::linspace(-10.0, 10.0, size * size * size);
        input.resize_({size, size, size});
        // Clamp reference: min_val below, identity inside, max_val above.
        const auto expected = (input < min_val) * min_val +
            ((input >= min_val) * (input <= max_val)) * input +
            (input > max_val) * max_val;
        auto output = F::hardtanh(
            input,
            F::HardtanhFuncOptions().min_val(min_val).max_val(max_val).inplace(
                inplace));

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(output, expected));
        if (inplace) {
          // In-place mode must overwrite the input itself.
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
  ASSERT_TRUE(F::hardtanh(torch::tensor(1.)).defined());
}
1229 
TEST_F(FunctionalTest, LeakyReLU) {
  const auto size = 3;
  for (const auto negative_slope : {0.0, 0.42, 1.0}) {
    for (const auto inplace : {false, true}) {
      for (const auto type : {torch::kFloat, torch::kBFloat16}) {
        auto input = torch::linspace(-10.0, 10.0, size * size * size).to(type);
        input.resize_({size, size, size});
        // Negative inputs are scaled by the slope; the rest pass through.
        const auto expected =
            (input < 0) * input * negative_slope + (input >= 0) * input;
        auto output = F::leaky_relu(
            input,
            F::LeakyReLUFuncOptions()
                .negative_slope(negative_slope)
                .inplace(inplace));

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(output, expected));
        if (inplace) {
          // In-place mode must overwrite the input itself.
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
  ASSERT_TRUE(F::leaky_relu(torch::tensor(1.)).defined());
}
1255 
TEST_F(FunctionalTest, LogSigmoid) {
  // logsigmoid(x) = log(1 / (1 + exp(-x))).
  // NOTE: the original declared an unused `LogSigmoid model;` local, which
  // the functional test never touched; it has been removed.
  const auto size = 3;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size});
  auto y = F::logsigmoid(x);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = torch::log(
      torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
}
1269 
TEST_F(FunctionalTest,GumbelSoftmax)1270 TEST_F(FunctionalTest, GumbelSoftmax) {
1271   // Test 1: No-options
1272   {
1273     auto logits = torch::randn({5});
1274     int expected_count = 1;
1275     auto y_draw = F::gumbel_softmax(logits);
1276 
1277     // All values positive
1278     ASSERT_GE(y_draw.min().item<int>(), 0);
1279     // Shape unchanged
1280     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1281     // One choice per draw
1282     ASSERT_TRUE(torch::allclose(
1283         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1284   }
1285 
1286   // Test 2: 1D shape, 0 and -1 dim
1287   for (const auto dim : {0, -1}) {
1288     auto logits = torch::randn({5});
1289     int expected_count = 1;
1290     auto y_draw = F::gumbel_softmax(
1291         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim));
1292 
1293     // All values positive
1294     ASSERT_GE(y_draw.min().item<int>(), 0);
1295     // Shape unchanged
1296     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1297     // One choice per draw
1298     ASSERT_TRUE(torch::allclose(
1299         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1300   }
1301 
1302   { // Test 3: 2D shape, 1 dim
1303     auto logits = torch::randn({5, 4});
1304     int expected_count = 5;
1305     auto y_draw = F::gumbel_softmax(
1306         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1));
1307 
1308     // All values positive
1309     ASSERT_GE(y_draw.min().item<int>(), 0);
1310     // Shape unchanged
1311     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1312     // One choice per draw
1313     ASSERT_TRUE(torch::allclose(
1314         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1315   }
1316 
1317   // Test 4: 3D shape, 1 and -1 dim
1318   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1319   int dims[] = {1, -1};
1320   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers)
1321   int expected[] = {5 * 3, 5 * 4};
1322   for (const auto i : c10::irange(2)) {
1323     auto logits = torch::randn({5, 4, 3});
1324     int expected_count = expected[i];
1325     auto y_draw = F::gumbel_softmax(
1326         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i]));
1327 
1328     // All values positive
1329     ASSERT_GE(y_draw.min().item<int>(), 0);
1330     // Shape unchanged
1331     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1332     // One choice per draw
1333     ASSERT_TRUE(torch::allclose(
1334         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1335   }
1336 
1337   { // Test 5: Straight through
1338     int num_draws = 100;
1339     auto logits = torch::tensor({{0.2, 0.8, 0.1}});
1340     logits = logits.reshape({1, 3});
1341     logits.requires_grad();
1342     auto probs = logits.softmax(-1);
1343 
1344     auto counts = torch::zeros_like(logits);
1345     torch::Tensor y_draw;
1346     for (const auto i : c10::irange(num_draws)) {
1347       (void)i; // Suppress unused variable warning
1348       y_draw =
1349           F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true));
1350       counts += y_draw;
1351     }
1352 
1353     // All values positive
1354     ASSERT_GE(y_draw.min().item<int>(), 0);
1355     // Each experiment should result in 1 draw
1356     ASSERT_EQ(counts.sum().item<int>(), num_draws);
1357 
1358     // Check results are asymptotically as expected
1359     auto expected = probs * num_draws;
1360     // ~z is approximately N(0,1) for unbiased count
1361     auto z = (counts - expected) / (expected * (1 - probs)).sqrt();
1362     // A (lazy) approximate 99% two-sided test:
1363     // occurs with prob alpha~>=0.01 if unbiased
1364     ASSERT_LT(z.abs().max().item<float>(), 2.58);
1365   }
1366 }
1367 
TEST_F(FunctionalTest, Softmax) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto output = F::softmax(input, /*dim=*/1);
  const auto denom = torch::sum(torch::exp(input), 1);

  // Each row must equal exp(x) normalized by the row sum.
  for (const auto row : c10::irange(2)) {
    const auto expected = torch::exp(input[row]) / denom[row];
    ASSERT_TRUE(torch::allclose(output[row], expected));
  }
}
1379 
TEST_F(FunctionalTest, Softmin) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto output = F::softmin(input, /*dim=*/1);
  const auto denom = torch::sum(torch::exp(-input), 1);

  // softmin(x) == softmax(-x): each row is exp(-x) over the row sum.
  for (const auto row : c10::irange(2)) {
    const auto expected = torch::exp(-input[row]) / denom[row];
    ASSERT_TRUE(torch::allclose(output[row], expected));
  }
}
1391 
TEST_F(FunctionalTest, LogSoftmax) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto output = F::log_softmax(input, /*dim=*/1);
  const auto denom = torch::sum(torch::exp(input), 1);

  // Each row must equal log(softmax(x)) computed explicitly.
  for (const auto row : c10::irange(2)) {
    const auto expected = torch::log(torch::exp(input[row]) / denom[row]);
    ASSERT_TRUE(torch::allclose(output[row], expected));
  }
}
1403 
TEST_F(FunctionalTest, PReLU) {
  const auto input = torch::rand({42, 24}) * 200 - 100;
  const auto weight = torch::rand(24) * 200 - 100;
  const auto output = F::prelu(input, weight);
  ASSERT_EQ(output.sizes(), std::vector<int64_t>({42, 24}));
  // Negative entries are scaled by the per-channel weight; the rest pass
  // through unchanged.
  const auto expected = (input < 0) * weight * input + (input >= 0) * input;
  ASSERT_TRUE(torch::allclose(output, expected));
}
1412 
TEST_F(FunctionalTest, LayerNorm) {
  // The functional wrapper must match the raw ATen op with no affine
  // weight/bias tensors.
  const auto input = torch::randn({2, 2});
  const auto actual =
      F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5));
  const auto reference =
      torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5);
  ASSERT_TRUE(torch::allclose(actual, reference));
}
1420 
TEST_F(FunctionalTest, GroupNorm) {
  // The functional wrapper must match the raw ATen op with no affine
  // weight/bias tensors.
  const auto input = torch::randn({2, 2});
  const auto actual = F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5));
  const auto reference =
      torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5);
  ASSERT_TRUE(torch::allclose(actual, reference));
}
1428 
TEST_F(FunctionalTest, LocalResponseNorm) {
  const auto input = torch::arange(100, 118).resize_({3, 3, 2});
  const auto output =
      F::local_response_norm(input, F::LocalResponseNormFuncOptions(2));
  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 2}));
  // Reference values computed with the Python API.
  const auto expected = torch::tensor(
      {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}},
       {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}},
       {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(output, expected, 1e-4, 1e-7));
}
1441 
TEST_F(FunctionalTest, Linear) {
  { // With bias: y = x * w^T + b.
    const auto x = torch::arange(100., 118).resize_({3, 3, 2});
    const auto w = torch::arange(200., 206).resize_({3, 2});
    const auto b = torch::arange(300., 303);
    const auto y = F::linear(x, w, b);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    // Hand-computed expectation for the arange inputs above.
    const auto y_exp = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
  { // Without bias: y = x * w^T (each value 300/301/302 lower than above).
    const auto x = torch::arange(100., 118).resize_({3, 3, 2});
    const auto w = torch::arange(200., 206).resize_({3, 2});
    const auto y = F::linear(x, w);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    const auto y_exp = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
}
1471 
TEST_F(FunctionalTest, Embedding) {
  // The functional wrapper must match the raw ATen op with its default
  // padding/scaling/sparsity arguments.
  const auto indices =
      torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  auto table = torch::empty({10, 3});
  torch::nn::init::normal_(table);
  const auto actual = F::embedding(indices, table);
  const auto reference =
      torch::embedding(table, indices.contiguous(), -1, false, false);
  ASSERT_TRUE(torch::allclose(actual, reference));
}
1480 
TEST_F(FunctionalTest, EmbeddingBag) {
  // Sum-mode bag with explicit offsets and a padding index; compared
  // against the raw ATen op called with matching positional arguments.
  const auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
  auto offsets = torch::tensor({0, 4}, torch::kLong);
  auto weight = torch::empty({10, 3});
  torch::nn::init::normal_(weight);
  auto y = F::embedding_bag(
      input,
      weight,
      F::EmbeddingBagFuncOptions()
          .mode(torch::kSum)
          .offsets(offsets)
          .padding_idx(4));
  // embedding_bag returns a tuple; element 0 is the pooled output.
  auto y_exp = std::get<0>(torch::embedding_bag(
      weight, input, offsets, false, 0, false, torch::Tensor(), false, 4));
  ASSERT_TRUE(torch::allclose(y, y_exp));

  // no options test
  // 2D input: offsets are implied by the row boundaries, mode defaults to
  // mean (mode argument 1 in the raw op).
  const auto input_ = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  auto offsets_ = torch::arange(
      0,
      input_.numel(),
      input_.size(1),
      torch::TensorOptions().dtype(torch::kLong).device(input.device()));
  y = F::embedding_bag(input_, weight);
  y_exp = std::get<0>(torch::embedding_bag(
      weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor()));
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1509 
// F::bilinear (x1^T W x2 + b) with and without the bias term, plus the error
// path when the three tensors disagree on dtype.
TEST_F(FunctionalTest, Bilinear) {
  auto input1 = torch::tensor({{1, 2, 3}, {7, 6, 5}});
  auto input2 = torch::tensor({{7, 4}, {8, 9}});
  auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}});
  auto bias = torch::tensor({1});

  auto y_with_bias = F::bilinear(input1, input2, weight, bias);
  ASSERT_EQ(y_with_bias.ndimension(), 2);
  ASSERT_EQ(y_with_bias.sizes(), torch::IntArrayRef({2, 1}));
  auto y_with_bias_exp = torch::tensor({{449}, {1702}}).reshape({2, 1});
  ASSERT_TRUE(torch::allclose(y_with_bias, y_with_bias_exp, 1e-4, 1e-7));

  // Same inputs without bias: expectations are exactly 1 lower per element.
  auto y_no_bias = F::bilinear(input1, input2, weight);
  ASSERT_EQ(y_no_bias.ndimension(), 2);
  ASSERT_EQ(y_no_bias.sizes(), torch::IntArrayRef({2, 1}));
  auto y_no_bias_exp = torch::tensor({{448, 1701}}).reshape({2, 1});
  ASSERT_TRUE(torch::allclose(y_no_bias, y_no_bias_exp, 1e-4, 1e-7));

  // Deliberately mismatch dtypes (double vs int) to trigger the check.
  input1 = input1.to(torch::kFloat64);
  input2 = input2.to(torch::kInt32);
  weight = weight.to(torch::kInt32);
  ASSERT_THROWS_WITH(
      F::bilinear(input1, input2, weight),
      "All tensors must have the same dtype, got input1: double, input2: int, weight: int");
}
1535 
// F::normalize (L1 normalization along the last dim here): checks values,
// gradient flow, the out= variant, default options, and the scalar edge case.
TEST_F(FunctionalTest, Normalize) {
  const auto expected = torch::tensor(
      {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000},
        {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}},
      torch::requires_grad().dtype(torch::kFloat));
  { // Test #1
    auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}},
        torch::dtype(torch::kFloat).requires_grad(true));
    auto norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));

    // reduce to scalar to call .backward()
    torch::Tensor s = norm.sum();
    s.backward();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(input.grad().numel(), 10);
    ASSERT_TRUE(torch::allclose(norm, expected));
  }

  { // Test #2 Check variations of optional arguments
    auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat));
    auto output = torch::randn({1, 2, 5}, torch::dtype(torch::kFloat));
    // non-null output argument
    F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1).out(output));
    // default options
    F::normalize(input);

    // Only the out= call is verified; the default-options call above just
    // needs to not throw.
    ASSERT_TRUE(torch::allclose(output, expected));
  }

  { // Test #3 Base case of scalar tensor
    auto input = torch::randn({}, torch::requires_grad());
    torch::Tensor norm =
        F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
    norm.backward();

    ASSERT_EQ(input.grad().numel(), 1);
  }
}
1577 
// F::relu via the options object and via the bare inplace overload; when
// inplace=true the input tensor itself must hold the result afterwards.
TEST_F(FunctionalTest, ReLU) {
  constexpr int64_t size = 3;
  for (const auto inplace : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size});
    const auto expected = (input < 0) * 0 + (input >= 0) * input;
    auto out = F::relu(input, F::ReLUFuncOptions().inplace(inplace));

    ASSERT_EQ(out.ndimension(), 3);
    ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(out, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }

    // NOLINTNEXTLINE(bugprone-argument-comment)
    out = F::relu(input, /*inplace=*/inplace);

    ASSERT_EQ(out.ndimension(), 3);
    ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(out, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }
  }
  // Scalar-tensor smoke test.
  ASSERT_TRUE(F::relu(torch::tensor(1.)).defined());
}
1605 
// F::relu with no options must behave as out-of-place ReLU.
TEST_F(FunctionalTest, ReLUDefaultOptions) {
  constexpr int64_t size = 3;
  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_({size, size, size});
  const auto expected = (input < 0) * 0 + (input >= 0) * input;
  const auto out = F::relu(input);

  ASSERT_EQ(out.ndimension(), 3);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
  ASSERT_TRUE(torch::allclose(out, expected));
}
1617 
// F::relu6 clamps to [0, 6]; checked via the options object and the bare
// inplace overload, including mutation of the input when inplace=true.
TEST_F(FunctionalTest, ReLU6) {
  constexpr int64_t size = 3;
  for (const auto inplace : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size});
    const auto expected =
        (input < 0) * 0 + ((input >= 0) * (input <= 6)) * input +
        (input > 6) * 6;
    auto out = F::relu6(input, F::ReLU6FuncOptions().inplace(inplace));

    ASSERT_EQ(out.ndimension(), 3);
    ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(out, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }

    // NOLINTNEXTLINE(bugprone-argument-comment)
    out = F::relu6(input, /*inplace=*/inplace);

    ASSERT_EQ(out.ndimension(), 3);
    ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(out, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }
  }
  // Scalar-tensor smoke test.
  ASSERT_TRUE(F::relu6(torch::tensor(1.)).defined());
}
1645 
// F::relu6 with no options: out-of-place clamp to [0, 6].
TEST_F(FunctionalTest, ReLU6DefaultOptions) {
  constexpr int64_t size = 3;
  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_({size, size, size});
  const auto expected =
      (input < 0) * 0 + ((input >= 0) * (input <= 6)) * input +
      (input > 6) * 6;
  const auto out = F::relu6(input);

  ASSERT_EQ(out.ndimension(), 3);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
  ASSERT_TRUE(torch::allclose(out, expected));
}
1657 
// F::rrelu draws a random negative slope per element from [lower, upper], so
// no exact expected tensor exists. Instead every output element is checked
// against the contract: y == x for x >= 0, and upper*x <= y <= lower*x for
// x < 0 (note x < 0 flips the inequality direction).
TEST_F(FunctionalTest, RReLU) {
  const auto size = 3;
  for (const auto lower : {0.01, 0.1, 0.2}) {
    for (const auto upper : {0.3, 0.4, 0.5}) {
      for (const auto inplace : {false, true}) {
        for (const auto type : {torch::kFloat, torch::kBFloat16}) {
          auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
          x.resize_({size, size, size});
          // Keep a pristine copy since inplace=true mutates x.
          auto x_copy = x.clone();
          auto y = F::rrelu(
              x,
              F::RReLUFuncOptions().lower(lower).upper(upper).inplace(inplace));
          // z is 1.0 exactly where the element satisfies the RReLU contract.
          auto z =
              ((x_copy >= 0) * (x_copy == y) +
               (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) *
              1.0;

          ASSERT_EQ(y.ndimension(), 3);
          ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
          ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
          if (inplace) {
            ASSERT_TRUE(torch::allclose(x, y));
          }
        }
      }
    }
  }
  ASSERT_TRUE(F::rrelu(torch::tensor(1.)).defined());
}
1687 
// F::rrelu with default bounds (lower = 1/8, upper = 1/3); the randomized
// output is validated against the same band-membership contract as above.
TEST_F(FunctionalTest, RReLUDefaultOptions) {
  const auto size = 3;
  const auto lower = 1.0 / 8.0;
  const auto upper = 1.0 / 3.0;
  for (const auto type : {torch::kFloat, torch::kBFloat16}) {
    auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
    x.resize_({size, size, size});
    auto x_copy = x.clone();
    auto y = F::rrelu(x);
    // z is 1.0 exactly where the element satisfies the RReLU contract:
    // identity for x >= 0, slope within [lower, upper] for x < 0.
    auto z = ((x_copy >= 0) * (x_copy == y) +
              (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) *
        1.0;

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
  }
}
1706 
// F::celu: max(0, x) + min(0, alpha*(exp(x/alpha) - 1)), checked across
// alphas, the inplace flag, and a bfloat16 variant with loose tolerance.
TEST_F(FunctionalTest, CELU) {
  constexpr int64_t size = 3;
  for (const auto inplace : {false, true}) {
    for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_({size, size, size});
      auto input_bf16 = input.clone().to(torch::kBFloat16);
      const auto zeros = torch::zeros_like(input);
      const auto expected = torch::max(zeros, input) +
          torch::min(zeros, alpha * (torch::exp(input / alpha) - 1.0));
      const auto opts = F::CELUFuncOptions().alpha(alpha).inplace(inplace);
      auto out = F::celu(input, opts);
      auto out_bf16 = F::celu(input_bf16, opts);

      ASSERT_EQ(out.ndimension(), 3);
      ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(out, expected));
      ASSERT_TRUE(torch::allclose(out_bf16.to(torch::kFloat), out, 1e-2, 1e-2));
      if (inplace) {
        ASSERT_TRUE(torch::allclose(input, expected));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), out, 1e-2, 1e-2));
      }
    }
  }
  // Scalar-tensor smoke test.
  ASSERT_TRUE(F::celu(torch::tensor(1.)).defined());
}
1733 
// F::celu with the default alpha (1.0), including a bfloat16 variant.
TEST_F(FunctionalTest, CELUDefaultOptions) {
  constexpr int64_t size = 3;
  constexpr double alpha = 1.0;
  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_({size, size, size});
  auto input_bf16 = input.clone().to(torch::kBFloat16);
  const auto zeros = torch::zeros_like(input);
  const auto expected = torch::max(zeros, input) +
      torch::min(zeros, alpha * (torch::exp(input / alpha) - 1.0));
  auto out = F::celu(input);
  auto out_bf16 = F::celu(input_bf16);

  ASSERT_EQ(out.ndimension(), 3);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
  ASSERT_TRUE(torch::allclose(out, expected));
  ASSERT_TRUE(torch::allclose(out_bf16.to(torch::kFloat), out, 1e-2, 1e-2));
}
1750 
// F::pixel_shuffle with upscale_factor=2 rearranges the (1, 4, 2, 2) input
// into (1, 1, 4, 4), interleaving the four channels into 2x2 spatial blocks.
TEST_F(FunctionalTest, PixelShuffle) {
  auto x = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto y_exp = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto y = F::pixel_shuffle(x, 2);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
  ASSERT_TRUE(y.allclose(y_exp));
}
1767 
// F::pixel_unshuffle with downscale_factor=2 is the exact inverse of the
// PixelShuffle test above: (1, 1, 4, 4) -> (1, 4, 2, 2) with the same data.
TEST_F(FunctionalTest, PixelUnshuffle) {
  auto x = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto y_exp = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto y = F::pixel_unshuffle(x, 2);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
  ASSERT_TRUE(y.allclose(y_exp));
}
1784 
// F::softplus across beta/threshold combinations: log(1 + exp(beta*x))/beta
// below the threshold, identity above it.
TEST_F(FunctionalTest, Softplus) {
  constexpr int64_t size = 3;
  for (const auto beta : {0.5, 1.0, 2.0}) {
    for (const auto threshold : {1.0, 3.0, 5.0}) {
      // resize_ truncates the 61-point linspace down to the 27 kept elements.
      auto input = torch::linspace(-3.0, 3.0, 61);
      input.resize_({size, size, size});
      const auto expected =
          (input <= threshold) * torch::log(1 + torch::exp(input * beta)) /
              beta +
          (input > threshold) * input;
      const auto out = F::softplus(
          input, F::SoftplusFuncOptions().beta(beta).threshold(threshold));

      ASSERT_EQ(out.ndimension(), 3);
      ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(out, expected));
    }
  }
}
1803 
// F::softplus with default options (beta = 1.0, threshold = 20.0).
TEST_F(FunctionalTest, SoftplusDefaultOptions) {
  constexpr int64_t size = 3;
  constexpr double beta = 1.0;
  constexpr double threshold = 20.0;
  // resize_ truncates the 61-point linspace down to the 27 kept elements.
  auto input = torch::linspace(-3.0, 3.0, 61);
  input.resize_({size, size, size});
  const auto expected =
      (input <= threshold) * torch::log(1 + torch::exp(input * beta)) / beta +
      (input > threshold) * input;
  const auto out = F::softplus(input);

  ASSERT_EQ(out.ndimension(), 3);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
  ASSERT_TRUE(torch::allclose(out, expected));
}
1818 
// F::fold combines the 2 sliding (2x2) blocks of ones into a (3, 2) output;
// overlapping positions accumulate, giving 2.0 in the middle row.
TEST_F(FunctionalTest, Fold) {
  auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::kDouble);
  auto output = F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2}));
  auto expected = torch::tensor(
      {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
      torch::kDouble);

  ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
  ASSERT_TRUE(output.allclose(expected));
}
1831 
// F::unfold extracts (2x2) patches with padding=1 and stride=2 from a
// (1, 2, 2, 3) input; padded positions appear as zeros in the patches.
TEST_F(FunctionalTest, Unfold) {
  auto input = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3});
  auto output =
      F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2));
  auto expected = torch::tensor(
      {{{0.0, 0.0, 0.0, 4.0},
        {0.0, 0.0, 3.0, 5.0},
        {0.0, 1.0, 0.0, 0.0},
        {0.0, 2.0, 0.0, 0.0},
        {0.0, 0.0, 0.0, 10.0},
        {0.0, 0.0, 9.0, 11.0},
        {0.0, 7.0, 0.0, 0.0},
        {6.0, 8.0, 0.0, 0.0}}},
      torch::kDouble);

  ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
  ASSERT_TRUE(output.allclose(expected));
}
1850 
// F::softshrink across several lambdas: shrinks values toward zero by lambda,
// zeroing the [-lambda, lambda] band; also checks gradient flow.
TEST_F(FunctionalTest, Softshrink) {
  constexpr int64_t size = 3;
  for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size}).set_requires_grad(true);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    auto out = F::softshrink(input, /*lambda=*/lambda);
    torch::Tensor total = out.sum();

    // Reduce to a scalar so backward() can run.
    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(out.ndimension(), 3);
    ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
    const auto expected = (input < -lambda) * (input + lambda) +
        (input > lambda) * (input - lambda);
    ASSERT_TRUE(torch::allclose(out, expected));
  }
}
1869 
// F::softshrink with the default lambda (0.5); also checks differentiability
// by reducing to a scalar and calling backward().
TEST_F(FunctionalTest, SoftshrinkDefaultOptions) {
  const auto size = 3;
  const auto lambda = 0.5;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size}).set_requires_grad(true);
  auto y = F::softshrink(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
  // Bug fix: y_exp was previously computed but never compared against y, so
  // the test asserted nothing about the actual output values.
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1885 
// F::softsign: x / (1 + |x|).
TEST_F(FunctionalTest, Softsign) {
  const auto input = torch::randn(100) * 10;
  const auto expected = input / (1 + input.abs());
  const auto out = F::softsign(input);

  ASSERT_TRUE(torch::allclose(out, expected));
}
1893 
// F::mish: x * tanh(softplus(x)), with softplus written as log1p(exp(x)).
TEST_F(FunctionalTest, Mish) {
  const auto input = torch::randn(100) * 10;
  const auto expected = input * input.exp().log1p().tanh();
  const auto out = F::mish(input);

  ASSERT_TRUE(torch::allclose(out, expected));
}
1901 
// F::tanhshrink: x - tanh(x).
TEST_F(FunctionalTest, Tanhshrink) {
  const auto input = torch::randn(100) * 10;
  const auto expected = input - input.tanh();
  const auto out = F::tanhshrink(input);

  ASSERT_TRUE(torch::allclose(out, expected));
}
1909 
// F::threshold replaces elements <= threshold with `value`; checked across
// threshold/value pairs and the inplace flag.
TEST_F(FunctionalTest, Threshold) {
  constexpr int64_t size = 3;
  for (const auto threshold : {0.5, 1.0, 2.0}) {
    for (const auto value : {0.5, 1.0, 2.0}) {
      for (const auto inplace : {false, true}) {
        // resize_ truncates the 61-point linspace to the 27 kept elements.
        auto input = torch::linspace(-3.0, 3.0, 61);
        input.resize_({size, size, size});
        const auto expected =
            (input <= threshold) * value + (input > threshold) * input;
        const auto out = F::threshold(
            input, F::ThresholdFuncOptions(threshold, value).inplace(inplace));

        ASSERT_EQ(out.ndimension(), 3);
        ASSERT_EQ(out.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(out, expected));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
  // Scalar-tensor smoke test.
  ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5))
                  .defined());
}
1933 
// F::batch_norm in eval mode on a (N, C) input with unit weight and zero
// bias must reduce to (x - mean) / sqrt(var + eps).
TEST_F(FunctionalTest, BatchNorm1d) {
  constexpr int num_features = 5;
  constexpr double eps = 1e-05;
  constexpr double momentum = 0.1;

  auto x = torch::randn({2, num_features});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  auto gamma = torch::ones({num_features});
  auto beta = torch::zeros({num_features});
  auto out = F::batch_norm(
      x,
      running_mean,
      running_var,
      F::BatchNormFuncOptions()
          .weight(gamma)
          .bias(beta)
          .momentum(momentum)
          .eps(eps)
          .training(false));
  auto expected = (x - running_mean) / torch::sqrt(running_var + eps);
  ASSERT_TRUE(out.allclose(expected));
}
1957 
// F::batch_norm with no options: no affine transform, default eps = 1e-5.
TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
  auto x = torch::randn({2, 5});
  auto running_mean = torch::randn(5);
  auto running_var = torch::rand(5);
  auto out = F::batch_norm(x, running_mean, running_var);
  auto expected = (x - running_mean) / torch::sqrt(running_var + 1e-5);
  ASSERT_TRUE(out.allclose(expected));
}
1966 
// F::batch_norm in eval mode on a (N, C, H, W) input. The reference moves
// the channel dim last so the per-channel stats broadcast, then moves it back.
TEST_F(FunctionalTest, BatchNorm2d) {
  constexpr int num_features = 5;
  constexpr double eps = 1e-05;
  constexpr double momentum = 0.1;

  auto x = torch::randn({2, num_features, 4, 4});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  auto gamma = torch::ones({num_features});
  auto beta = torch::zeros({num_features});
  auto out = F::batch_norm(
      x,
      running_mean,
      running_var,
      F::BatchNormFuncOptions()
          .weight(gamma)
          .bias(beta)
          .momentum(momentum)
          .eps(eps)
          .training(false));
  auto normalized = (torch::transpose(x, 1, 3) - running_mean) /
      torch::sqrt(running_var + eps);
  auto expected = torch::transpose(normalized, 1, 3);
  ASSERT_TRUE(out.allclose(expected));
}
1993 
// F::batch_norm with no options on a (N, C, H, W) input.
TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) {
  constexpr int num_features = 5;
  constexpr double eps = 1e-05;

  auto x = torch::randn({2, num_features, 4, 4});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  auto out = F::batch_norm(x, running_mean, running_var);
  // Channel dim moved last for broadcasting, then restored.
  auto normalized = (torch::transpose(x, 1, 3) - running_mean) /
      torch::sqrt(running_var + eps);
  auto expected = torch::transpose(normalized, 1, 3);
  ASSERT_TRUE(out.allclose(expected));
}
2008 
// F::batch_norm in eval mode on a (N, C, D, H, W) input; channel dim is
// swapped to the end (dim 4) for the broadcasted reference computation.
TEST_F(FunctionalTest, BatchNorm3d) {
  constexpr int num_features = 5;
  constexpr double eps = 1e-05;
  constexpr double momentum = 0.1;

  auto x = torch::randn({2, num_features, 2, 2, 2});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  auto gamma = torch::ones({num_features});
  auto beta = torch::zeros({num_features});
  auto out = F::batch_norm(
      x,
      running_mean,
      running_var,
      F::BatchNormFuncOptions()
          .weight(gamma)
          .bias(beta)
          .momentum(momentum)
          .eps(eps)
          .training(false));
  auto normalized = (torch::transpose(x, 1, 4) - running_mean) /
      torch::sqrt(running_var + eps);
  auto expected = torch::transpose(normalized, 1, 4);
  ASSERT_TRUE(out.allclose(expected));
}
2035 
// F::batch_norm with no options on a (N, C, D, H, W) input.
TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
  constexpr int num_features = 5;
  constexpr double eps = 1e-05;

  auto x = torch::randn({2, num_features, 2, 2, 2});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  auto out = F::batch_norm(x, running_mean, running_var);
  // Channel dim moved last (dim 4) for broadcasting, then restored.
  auto normalized = (torch::transpose(x, 1, 4) - running_mean) /
      torch::sqrt(running_var + eps);
  auto expected = torch::transpose(normalized, 1, 4);
  ASSERT_TRUE(out.allclose(expected));
}
2050 
// F::instance_norm on a (N, C, L) input with running stats and affine
// weight/bias set to arange(C). Expected values are hard-coded; both batch
// entries normalize to the same values since each row is an identical ramp.
TEST_F(FunctionalTest, InstanceNorm1d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(40.).view({2, 5, 4});
  auto mean = torch::arange(5.);
  auto variance = torch::arange(5.);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(mean)
          .running_var(variance)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));
  // Channel 0 is all zeros because weight[0] == bias[0] == 0.
  auto expected = torch::tensor(
      {{{0.0000, 0.0000, 0.0000, 0.0000},
        {-0.3416, 0.5528, 1.4472, 2.3416},
        {-0.6833, 1.1056, 2.8944, 4.6833},
        {-1.0249, 1.6584, 4.3416, 7.0249},
        {-1.3665, 2.2112, 5.7888, 9.3665}},
       {{0.0000, 0.0000, 0.0000, 0.0000},
        {-0.3416, 0.5528, 1.4472, 2.3416},
        {-0.6833, 1.1056, 2.8944, 4.6833},
        {-1.0249, 1.6584, 4.3416, 7.0249},
        {-1.3665, 2.2112, 5.7888, 9.3665}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2083 
// F::instance_norm with no options: every (channel, instance) row is the
// ramp {k, k+1, k+2, k+3}, so each normalizes to the same four values.
TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
  auto input = torch::arange(40.).view({2, 5, 4});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor(
      {{{-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416}},
       {{-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416},
        {-1.3416, -0.4472, 0.4472, 1.3416}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2100 
// F::instance_norm on a (N, C, H, W) input with running stats and affine
// weight/bias set to arange(C); expected values are hard-coded.
TEST_F(FunctionalTest, InstanceNorm2d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto mean = torch::arange((double)num_features);
  auto variance = torch::arange((double)num_features);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(mean)
          .running_var(variance)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));
  // Channel 0 is all zeros because weight[0] == bias[0] == 0.
  auto expected = torch::tensor(
      {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
        {{-0.3416, 0.5528}, {1.4472, 2.3416}},
        {{-0.6833, 1.1056}, {2.8944, 4.6833}},
        {{-1.0249, 1.6584}, {4.3416, 7.0249}},
        {{-1.3665, 2.2112}, {5.7888, 9.3665}}},
       {{{0.0000, 0.0000}, {0.0000, 0.0000}},
        {{-0.3416, 0.5528}, {1.4472, 2.3416}},
        {{-0.6833, 1.1056}, {2.8944, 4.6833}},
        {{-1.0249, 1.6584}, {4.3416, 7.0249}},
        {{-1.3665, 2.2112}, {5.7888, 9.3665}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2134 
// F::instance_norm with no options on a (N, C, 2, 2) arange input: every
// per-instance channel is a ramp of 4 values, so all normalize identically.
TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
  int num_features = 5;

  auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor(
      {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
       {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2154 
// F::instance_norm on a (N, C, D, H, W) input with running stats and affine
// weight/bias set to arange(C); expected values are hard-coded.
TEST_F(FunctionalTest, InstanceNorm3d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                   .view({2, num_features, 2, 2, 2});
  auto mean = torch::arange((double)num_features);
  auto variance = torch::arange((double)num_features);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(mean)
          .running_var(variance)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));
  // Channel 0 is all zeros because weight[0] == bias[0] == 0.
  auto expected = torch::tensor(
      {{{{{0.0000, 0.0000}, {0.0000, 0.0000}},
         {{0.0000, 0.0000}, {0.0000, 0.0000}}},
        {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
         {{1.2182, 1.6547}, {2.0911, 2.5275}}},
        {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
         {{2.4364, 3.3093}, {4.1822, 5.0550}}},
        {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
         {{3.6547, 4.9640}, {6.2733, 7.5826}}},
        {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
         {{4.8729, 6.6186}, {8.3644, 10.1101}}}},
       {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
         {{0.0000, 0.0000}, {0.0000, 0.0000}}},
        {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
         {{1.2182, 1.6547}, {2.0911, 2.5275}}},
        {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
         {{2.4364, 3.3093}, {4.1822, 5.0550}}},
        {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
         {{3.6547, 4.9640}, {6.2733, 7.5826}}},
        {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
         {{4.8729, 6.6186}, {8.3644, 10.1101}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2198 
// F::instance_norm with no options on a (N, C, 2, 2, 2) arange input: every
// per-instance channel is a ramp of 8 values, so all normalize identically.
TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
  int num_features = 5;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                   .view({2, num_features, 2, 2, 2});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor(
      {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
       {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2228 
TEST_F(FunctionalTest,Interpolate)2229 TEST_F(FunctionalTest, Interpolate) {
2230   {
2231     // 1D interpolation
2232     auto input = torch::ones({1, 1, 2});
2233     auto options = F::InterpolateFuncOptions()
2234                        .size(std::vector<int64_t>({4}))
2235                        .mode(torch::kNearest);
2236     auto output = F::interpolate(input, options);
2237     auto expected = torch::ones({1, 1, 4});
2238 
2239     ASSERT_TRUE(output.allclose(expected));
2240   }
2241   {
2242     // 2D interpolation
2243     for (const auto align_corners : {true, false}) {
2244       // test float scale factor up & down sampling
2245       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2246         auto input = torch::ones({1, 1, 2, 2});
2247         auto options =
2248             F::InterpolateFuncOptions()
2249                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
2250                 .mode(torch::kBilinear)
2251                 .align_corners(align_corners);
2252         auto output = F::interpolate(input, options);
2253         auto expected_size =
2254             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2255         auto expected = torch::ones({1, 1, expected_size, expected_size});
2256 
2257         ASSERT_TRUE(output.allclose(expected));
2258       }
2259     }
2260   }
2261   {
2262     // 3D interpolation
2263     for (const auto align_corners : {true, false}) {
2264       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2265         auto input = torch::ones({1, 1, 2, 2, 2});
2266         auto options = F::InterpolateFuncOptions()
2267                            .scale_factor(std::vector<double>(
2268                                {scale_factor, scale_factor, scale_factor}))
2269                            .mode(torch::kTrilinear)
2270                            .align_corners(align_corners);
2271         auto output = F::interpolate(input, options);
2272         auto expected_size =
2273             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2274         auto expected =
2275             torch::ones({1, 1, expected_size, expected_size, expected_size});
2276 
2277         ASSERT_TRUE(output.allclose(expected));
2278       }
2279     }
2280   }
2281   {
2282     ASSERT_THROWS_WITH(
2283         F::interpolate(
2284             torch::randn({1}),
2285             F::InterpolateFuncOptions().size(std::vector<int64_t>({1}))),
2286         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 1D) ");
2287   }
2288   {
2289     auto input = torch::randn({3, 2, 2});
2290     ASSERT_THROWS_WITH(
2291         F::interpolate(
2292             input[0],
2293             F::InterpolateFuncOptions().size(std::vector<int64_t>({4, 4}))),
2294         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) "
2295         "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2296     ASSERT_THROWS_WITH(
2297         F::interpolate(
2298             torch::reshape(input, {1, 1, 1, 3, 2, 2}),
2299             F::InterpolateFuncOptions().size(
2300                 std::vector<int64_t>({1, 1, 1, 3, 4, 4}))),
2301         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) "
2302         "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2303     ASSERT_THROWS_WITH(
2304         F::interpolate(input, F::InterpolateFuncOptions()),
2305         "either size or scale_factor should be defined");
2306     ASSERT_THROWS_WITH(
2307         F::interpolate(
2308             input,
2309             F::InterpolateFuncOptions()
2310                 .size(std::vector<int64_t>({3, 4, 4}))
2311                 .scale_factor(std::vector<double>({0.5}))),
2312         "only one of size or scale_factor should be defined");
2313     ASSERT_THROWS_WITH(
2314         F::interpolate(
2315             input,
2316             F::InterpolateFuncOptions().scale_factor(
2317                 std::vector<double>({3, 2}))),
2318         "scale_factor shape must match input shape. "
2319         "Input is 1D, scale_factor size is [3, 2]");
2320     ASSERT_THROWS_WITH(
2321         F::interpolate(
2322             input,
2323             F::InterpolateFuncOptions()
2324                 .mode(torch::kNearest)
2325                 .align_corners(true)),
2326         "align_corners option can only be set with the "
2327         "interpolating modes: linear | bilinear | bicubic | trilinear");
2328   }
2329   {
2330     auto tensor = torch::rand({2, 3, 32, 32});
2331     std::vector<int64_t> osize = {8, 10};
2332     auto expected =
2333         at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);
2334 
2335     auto options = F::InterpolateFuncOptions()
2336                        .size(osize)
2337                        .mode(torch::kNearestExact)
2338                        .align_corners(false);
2339     auto output = F::interpolate(tensor, options);
2340 
2341     ASSERT_TRUE(output.allclose(expected));
2342   }
2343   {
2344     auto tensor = torch::rand({2, 3, 32, 32});
2345     std::vector<int64_t> osize = {8, 10};
2346     auto expected = at::native::_upsample_bilinear2d_aa(
2347         tensor, osize, false, torch::nullopt);
2348 
2349     auto options = F::InterpolateFuncOptions()
2350                        .size(osize)
2351                        .mode(torch::kBilinear)
2352                        .align_corners(false)
2353                        .antialias(true);
2354     auto output = F::interpolate(tensor, options);
2355     ASSERT_TRUE(output.allclose(expected));
2356   }
2357   {
2358     auto tensor = torch::rand({2, 3, 32, 32});
2359     std::vector<int64_t> osize = {8, 10};
2360     auto expected = at::native::_upsample_bicubic2d_aa(
2361         tensor, osize, false, torch::nullopt);
2362 
2363     auto options = F::InterpolateFuncOptions()
2364                        .size(osize)
2365                        .mode(torch::kBicubic)
2366                        .align_corners(false)
2367                        .antialias(true);
2368     auto output = F::interpolate(tensor, options);
2369     ASSERT_TRUE(output.allclose(expected));
2370   }
2371 }
2372 
TEST_F(FunctionalTest, Pad1) {
  {
    // Circular padding on a 3-D input: pad = {left=1, right=2} wraps values
    // around the last dimension, so each row of length 3 becomes length 6.
    const auto in = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
    const auto padded =
        F::pad(in, F::PadFuncOptions({1, 2}).mode(torch::kCircular));
    const auto want = torch::tensor(
        {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble);
    ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 2, 6}));
    ASSERT_TRUE(padded.allclose(want, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad2) {
  {
    // Circular padding on a 4-D input. Pad order is last-dim pair first:
    // {left=3, right=3, top=3, bottom=1}, so the 3x3 block tiles around
    // both spatial dimensions, giving a 7x9 output.
    auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
    auto output =
        F::pad(input, F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular));
    auto expected = torch::tensor(
        {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
           {3., 4., 5., 3., 4., 5., 3., 4., 5.},
           {6., 7., 8., 6., 7., 8., 6., 7., 8.},
           {0., 1., 2., 0., 1., 2., 0., 1., 2.},
           {3., 4., 5., 3., 4., 5., 3., 4., 5.},
           {6., 7., 8., 6., 7., 8., 6., 7., 8.},
           {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 9}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad3) {
  {
    // Circular padding on a 5-D input. Pad order (last dim first):
    // {left=3, right=3, top=2, bottom=1, front=2, back=2}. The 2x2x3 volume
    // wraps in all three spatial dimensions, yielding a 6x5x9 output.
    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
    auto expected = torch::tensor(
        {{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}},

           {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}},

           {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad4) {
  {
    // Reflection padding: pad = {left=1, right=1, top=1, bottom=1}; each 2x2
    // spatial plane is mirrored (without repeating the edge) into a 4x4 plane.
    auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
    auto output =
        F::pad(input, F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect));
    auto expected = torch::tensor(
        {{{{3., 2., 3., 2.},
           {1., 0., 1., 0.},
           {3., 2., 3., 2.},
           {1., 0., 1., 0.}},

          {{7., 6., 7., 6.},
           {5., 4., 5., 4.},
           {7., 6., 7., 6.},
           {5., 4., 5., 4.}}},

         {{{11., 10., 11., 10.},
           {9., 8., 9., 8.},
           {11., 10., 11., 10.},
           {9., 8., 9., 8.}},

          {{15., 14., 15., 14.},
           {13., 12., 13., 12.},
           {15., 14., 15., 14.},
           {13., 12., 13., 12.}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({2, 2, 4, 4}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad5) {
  {
    // Replication padding on a 5-D input: pad = {left=1, right=2, top=2,
    // bottom=1, front=1, back=2}; border values are repeated outward,
    // producing a 5x5x6 volume.
    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
    auto expected = torch::tensor(
        {{{{{0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {3., 3., 4., 5., 5., 5.},
            {3., 3., 4., 5., 5., 5.}},

           {{0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {3., 3., 4., 5., 5., 5.},
            {3., 3., 4., 5., 5., 5.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad6) {
  {
    // Reflection padding on a 5-D input with asymmetric pads:
    // {left=0, right=2, top=1, bottom=0, front=1, back=2} → 6x3x5 output.
    auto input = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect));
    auto expected = torch::tensor(
        {{{{{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{3., 4., 5., 4., 3.}, {0., 1., 2., 1., 0.}, {3., 4., 5., 4., 3.}},

           {{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{15., 16., 17., 16., 15.},
            {12., 13., 14., 13., 12.},
            {15., 16., 17., 16., 15.}},

           {{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{3., 4., 5., 4., 3.},
            {0., 1., 2., 1., 0.},
            {3., 4., 5., 4., 3.}}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 3, 5}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
TEST_F(FunctionalTest, Pad7) {
  {
    // Constant padding with an explicit value of 0: {left=1, right=1} on the
    // last dim turns the single element into {0, 1, 0}.
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(
        input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Bug fix: `expected` was previously computed but never compared against
    // `output`, so the value check silently did nothing.
    ASSERT_TRUE(output.allclose(expected));
  }
}
TEST_F(FunctionalTest, Pad8) {
  {
    // Default PadFuncOptions: constant mode with value 0, so the result must
    // match the explicit-options Pad7 case.
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(input, F::PadFuncOptions({1, 1}));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Bug fix: `expected` was previously computed but never compared against
    // `output`, so the value check silently did nothing.
    ASSERT_TRUE(output.allclose(expected));
  }
}
2567 
TEST_F(FunctionalTest, CTCLoss) {
  { // test CTCLoss typechecks
    // ctc_loss must reject non-integral length tensors with a clear message.
    const auto target_lengths = torch::tensor({30, 25, 20});
    const auto input_lengths = torch::tensor({50, 50, 50});
    const auto targets =
        torch::randint(1, 15, {target_lengths.sum().item<int>()}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);

    const auto _input_lengths = input_lengths.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, _input_lengths, target_lengths),
        "input_lengths must be integral");

    const auto target_lengths_ = target_lengths.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, input_lengths, target_lengths_),
        "target_lengths must be integral");
  }
  { // test CTCLoss length checks
    // 2-D targets of width 29 cannot satisfy a target length of 30.
    const auto target_lengths = torch::tensor({30, 25, 20});
    const auto input_lengths = torch::tensor({50, 50, 50});
    const auto targets = torch::randint(1, 15, {3, 29}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, input_lengths, target_lengths),
        "Expected tensor to have size at least 30 at dimension 1");
  }
  { // test CTCLoss empty target
    {
      // With all-empty targets, the per-sample loss is the negative
      // log-probability of emitting blank (class 0) at every time step.
      const auto target_lengths = torch::tensor({0, 0, 0});
      const auto input_lengths = torch::tensor({50, 50, 50});
      const auto targets =
          torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss = F::ctc_loss(
          log_probs,
          targets,
          input_lengths,
          target_lengths,
          F::CTCLossFuncOptions().reduction(torch::kNone));
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      // Reference: sum log-probs over time, take the blank column (index 0).
      ASSERT_TRUE(torch::allclose(
          -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss));
    }
    {
      // Mixed batch: samples 0 and 2 have empty targets, sample 1 does not.
      // Only the empty-target samples are checked against the blank-only
      // reference value.
      const auto target_lengths = torch::tensor({0, 9, 0});
      const auto input_lengths = torch::tensor({50, 50, 50});
      const auto targets = torch::randint(1, 15, {9}, torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss = F::ctc_loss(
          log_probs,
          targets,
          input_lengths,
          target_lengths,
          F::CTCLossFuncOptions().reduction(torch::kNone));
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      ASSERT_TRUE(torch::allclose(
          -log_probs.sum(0)
               .index_select(0, torch::tensor({0, 2}, torch::kLong))
               .slice(1, 0, 1)
               .view({2}),
          loss.index_select(0, torch::tensor({0, 2}, torch::kLong))));
    }
  }
}
2637 
TEST_F(FunctionalTest, PoissonNLLLoss) {
  // With the default log-input convention, the per-element Poisson NLL is
  // exp(input) - target * input; the reductions just aggregate it.
  const auto in = torch::tensor({0.5, 1.5, 2.5});
  const auto tgt = torch::tensor({1., 2., 3.});
  const auto per_element = torch::exp(in) - tgt * in;

  // Default reduction is mean.
  ASSERT_TRUE(torch::allclose(
      torch::mean(per_element), F::poisson_nll_loss(in, tgt)));
  // kNone keeps the element-wise losses.
  ASSERT_TRUE(torch::allclose(
      per_element,
      F::poisson_nll_loss(
          in, tgt, F::PoissonNLLLossFuncOptions().reduction(torch::kNone))));
  // kSum aggregates by summation.
  ASSERT_TRUE(torch::allclose(
      torch::sum(per_element),
      F::poisson_nll_loss(
          in, tgt, F::PoissonNLLLossFuncOptions().reduction(torch::kSum))));
  // Explicit kMean matches the default behavior.
  ASSERT_TRUE(torch::allclose(
      torch::mean(per_element),
      F::poisson_nll_loss(
          in, tgt, F::PoissonNLLLossFuncOptions().reduction(torch::kMean))));
}
2663 
TEST_F(FunctionalTest, MarginRankingLoss) {
  // margin_ranking_loss(x1, x2, y) = reduce(max(0, -y * (x1 - x2) + margin)).
  {
    // Default options: margin 0, mean reduction.
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(input1, input2, target),
        (-target * (input1 - input2)).clamp(0).mean()));
  }
  {
    // Explicit margin with sum reduction. Consistency fix: pass the `margin`
    // local to the options instead of duplicating the literal 0.5, so the
    // option value and the reference formula cannot drift apart.
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    const auto margin = 0.5;
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(
            input1,
            input2,
            target,
            F::MarginRankingLossFuncOptions().margin(margin).reduction(
                torch::kSum)),
        (-target * (input1 - input2) + margin).clamp(0).sum()));
  }
  {
    // Explicit margin with mean reduction.
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    const auto margin = 0.5;
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(
            input1,
            input2,
            target,
            F::MarginRankingLossFuncOptions().margin(margin).reduction(
                torch::kMean)),
        (-target * (input1 - input2) + margin).clamp(0).mean()));
  }
}
2702 
TEST_F(FunctionalTest, ConvTranspose1d) {
  // Transposed 1-D convolution with stride 1; output length = 5 + 3 - 1 = 7.
  auto x = torch::arange(20.).view({2, 2, 5});
  auto weight = torch::arange(18.).view({2, 3, 3});
  auto y =
      F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // The default options must be equivalent to stride(1).
  auto y_no_options = F::conv_transpose1d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2720 
TEST_F(FunctionalTest, ConvTranspose2dEven) {
  // Transposed 2-D convolution with a square (3x3) kernel and stride 1;
  // output spatial size = 5 + 3 - 1 = 7 per dimension.
  auto x = torch::arange(50.).view({1, 2, 5, 5});
  auto weight = torch::arange(54.).view({2, 3, 3, 3});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // The default options must be equivalent to stride(1).
  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2753 
TEST_F(FunctionalTest, ConvTranspose2dUneven) {
  // Transposed 2-D convolution with a non-square (3x2) kernel and stride 1;
  // output size = (5 + 3 - 1) x (4 + 2 - 1) = 7 x 5.
  auto x = torch::arange(40.).view({1, 2, 5, 4});
  auto weight = torch::arange(36.).view({2, 3, 3, 2});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // The default options must be equivalent to stride(1).
  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2786 
TEST_F(FunctionalTest, ConvTranspose3d) {
  // Transposed 3-D convolution with a 2x2x2 kernel and stride 1; each
  // spatial dimension grows from 2 to 2 + 2 - 1 = 3.
  auto x = torch::arange(16.).view({1, 2, 2, 2, 2});
  auto weight = torch::arange(32.).view({2, 2, 2, 2, 2});
  auto y =
      F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // The default options must be equivalent to stride(1).
  auto y_no_options = F::conv_transpose3d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2804 
TEST_F(FunctionalTest, AlphaDropout) {
  // With training=false, alpha dropout is a no-op, so the output's mean and
  // std must track the input's regardless of rate or inplace mode.
  auto input = torch::randn(5000);
  const auto mean_before = input.mean();
  const auto std_before = input.std();

  for (const auto p : {0.2, 0.5, 0.8}) {
    for (const auto in_place : {false, true}) {
      auto work = input.clone();
      auto result = F::alpha_dropout(
          work,
          F::AlphaDropoutFuncOptions().p(p).training(false).inplace(in_place));
      ASSERT_TRUE(torch::allclose(mean_before, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(std_before, result.std(), 0.1));
      if (in_place) {
        // In-place mode must write the result through to the input tensor.
        ASSERT_TRUE(torch::allclose(work, result));
      }
    }
  }
  // The detail-namespace entry point with training=false behaves the same.
  auto result = F::detail::alpha_dropout(input, 0.5, false, false);
  ASSERT_TRUE(torch::allclose(mean_before, result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(std_before, result.std(), 0.1));
}
2828 
TEST_F(FunctionalTest, FeatureAlphaDropout) {
  // With training=false, feature alpha dropout is a no-op, so output
  // statistics must track the input's for every rate/inplace combination.
  auto input = torch::randn(5000);
  const auto mean_before = input.mean();
  const auto std_before = input.std();

  for (const auto p : {0.2, 0.5, 0.8}) {
    for (const auto in_place : {false, true}) {
      auto work = input.clone();
      auto result = F::feature_alpha_dropout(
          work,
          F::FeatureAlphaDropoutFuncOptions().p(p).training(false).inplace(
              in_place));
      ASSERT_TRUE(torch::allclose(mean_before, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(std_before, result.std(), 0.1));
      if (in_place) {
        // In-place mode must write the result through to the input tensor.
        ASSERT_TRUE(torch::allclose(work, result));
      }
    }
  }
  // Default options (training defaults apply) must also preserve statistics.
  auto result = F::feature_alpha_dropout(input);
  ASSERT_TRUE(torch::allclose(mean_before, result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(std_before, result.std(), 0.1));
}
2852 
TEST_F(FunctionalTest, Dropout) {
  // Dropout (training mode) rescales surviving elements, so the mean stays
  // roughly unchanged while the standard deviation can only grow.
  auto input = torch::randn(5000);
  const auto mean_before = input.mean();
  const auto std_before = input.std();

  for (const auto p : {0.2, 0.5, 0.8}) {
    auto dropped = F::dropout(input, F::DropoutFuncOptions().p(p));
    ASSERT_TRUE(torch::allclose(mean_before, dropped.mean(), 0.01, 0.05));
    ASSERT_TRUE((std_before <= dropped.std()).all().item<bool>());
  }
  // Default options (p = 0.5) behave the same way.
  auto dropped = F::dropout(input);
  ASSERT_TRUE(torch::allclose(mean_before, dropped.mean(), 0.01, 0.05));
  ASSERT_TRUE((std_before <= dropped.std()).all().item<bool>());
  // A scalar input must still produce a defined tensor.
  ASSERT_TRUE(F::dropout(torch::tensor(1.)).defined());
}
2868 
TEST_F(FunctionalTest, Dropout2d) {
  // Channel-wise 2-D dropout keeps the overall mean roughly unchanged.
  auto input = torch::randn({2, 2, 50, 100});
  // Cleanup: the unused `input_std` local was removed (the std-based check in
  // the plain Dropout test is not performed here).
  auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }
  // Default options (p = 0.5).
  auto output = F::dropout2d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  // A 3-D (unbatched) input must still produce a defined tensor.
  ASSERT_TRUE(F::dropout2d(torch::randn({2, 50, 100})).defined());
}
2882 
TEST_F(FunctionalTest, Dropout3d) {
  // Channel-wise 3-D dropout keeps the overall mean roughly unchanged.
  auto input = torch::randn({2, 2, 50, 10, 10});
  // Cleanup: the unused `input_std` local was removed (no std-based check is
  // performed in this test).
  auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }
  // Default options (p = 0.5).
  auto output = F::dropout3d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  // A 4-D (unbatched) input must still produce a defined tensor.
  ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined());
}
2896 
2897 template <c10::ScalarType S, typename T>
test_isfinite(const at::Device & device)2898 void test_isfinite(const at::Device& device) {
2899   const std::vector<T> values = {
2900       std::numeric_limits<T>::lowest(),
2901       0,
2902       1,
2903       42,
2904       std::numeric_limits<T>::min(),
2905       std::numeric_limits<T>::max()};
2906   for (const auto value : values) {
2907     const auto x = torch::full(
2908         {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2909     ASSERT_TRUE(torch::isfinite(x).all().template item<bool>());
2910   }
2911   if (std::numeric_limits<T>::has_infinity) {
2912     const auto inf = std::numeric_limits<T>::infinity();
2913     const auto x = torch::tensor(
2914         {-inf,
2915          std::numeric_limits<T>::lowest(),
2916          static_cast<T>(0),
2917          static_cast<T>(1),
2918          static_cast<T>(42),
2919          std::numeric_limits<T>::min(),
2920          std::numeric_limits<T>::max(),
2921          inf},
2922         torch::TensorOptions().dtype(S).device(device));
2923     ASSERT_TRUE(torch::allclose(
2924         // torch::allclose does not support comparing torch::kBool
2925         torch::isfinite(x).toType(torch::kInt),
2926         torch::tensor(
2927             {false, true, true, true, true, true, true, false},
2928             torch::TensorOptions().device(device))
2929             .toType(torch::kInt)));
2930   }
2931   if (std::numeric_limits<T>::has_quiet_NaN) {
2932     const auto x = torch::tensor(
2933         {std::numeric_limits<T>::quiet_NaN()},
2934         torch::TensorOptions().dtype(S).device(device));
2935     ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2936   }
2937   if (std::numeric_limits<T>::has_signaling_NaN) {
2938     const auto x = torch::tensor(
2939         {std::numeric_limits<T>::signaling_NaN()},
2940         torch::TensorOptions().dtype(S).device(device));
2941     ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2942   }
2943 }
2944 
TEST_F(FunctionalTest, isfinite) {
  // Exercise isfinite on CPU across all supported integral and float dtypes.
  const at::Device cpu_device("cpu");
  test_isfinite<torch::kUInt8, uint8_t>(cpu_device);
  test_isfinite<torch::kInt8, int8_t>(cpu_device);
  test_isfinite<torch::kInt16, int16_t>(cpu_device);
  test_isfinite<torch::kInt32, int32_t>(cpu_device);
  test_isfinite<torch::kInt64, int64_t>(cpu_device);
  test_isfinite<torch::kFloat32, float>(cpu_device);
  test_isfinite<torch::kFloat64, double>(cpu_device);
}
2955 
TEST_F(FunctionalTest, isfinite_CUDA) {
  // Same dtype sweep as the CPU test, plus half precision (CUDA-only here).
  const at::Device cuda_device("cuda");
  test_isfinite<torch::kUInt8, uint8_t>(cuda_device);
  test_isfinite<torch::kInt8, int8_t>(cuda_device);
  test_isfinite<torch::kInt16, int16_t>(cuda_device);
  test_isfinite<torch::kInt32, int32_t>(cuda_device);
  test_isfinite<torch::kInt64, int64_t>(cuda_device);
  test_isfinite<torch::kFloat32, float>(cuda_device);
  test_isfinite<torch::kFloat64, double>(cuda_device);
  test_isfinite<torch::kFloat16, c10::Half>(cuda_device);
}
2967 
2968 template <c10::ScalarType S, typename T>
test_isinf(const at::Device & device)2969 void test_isinf(const at::Device& device) {
2970   const std::vector<T> values = {
2971       std::numeric_limits<T>::lowest(),
2972       0,
2973       1,
2974       42,
2975       std::numeric_limits<T>::min(),
2976       std::numeric_limits<T>::max()};
2977   for (const auto value : values) {
2978     const auto x = torch::full(
2979         {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2980     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2981   }
2982   if (std::numeric_limits<T>::has_infinity) {
2983     const auto inf = std::numeric_limits<T>::infinity();
2984     const auto x = torch::tensor(
2985         {-inf,
2986          std::numeric_limits<T>::lowest(),
2987          static_cast<T>(0),
2988          static_cast<T>(1),
2989          static_cast<T>(42),
2990          std::numeric_limits<T>::min(),
2991          std::numeric_limits<T>::max(),
2992          inf},
2993         torch::TensorOptions().dtype(S).device(device));
2994     ASSERT_TRUE(torch::allclose(
2995         // torch::allclose does not support comparing torch::kBool
2996         torch::isinf(x).toType(torch::kInt),
2997         torch::tensor(
2998             {true, false, false, false, false, false, false, true},
2999             torch::TensorOptions().device(device))
3000             .toType(torch::kInt)));
3001   }
3002   if (std::numeric_limits<T>::has_quiet_NaN) {
3003     const auto x = torch::tensor(
3004         {std::numeric_limits<T>::quiet_NaN()},
3005         torch::TensorOptions().dtype(S).device(device));
3006     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3007   }
3008   if (std::numeric_limits<T>::has_signaling_NaN) {
3009     const auto x = torch::tensor(
3010         {std::numeric_limits<T>::signaling_NaN()},
3011         torch::TensorOptions().dtype(S).device(device));
3012     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3013   }
3014 }
3015 
// Exercises torch::isinf on CPU across all integral dtypes plus 32/64-bit
// floats. kFloat16 is exercised only in the CUDA variant.
TEST_F(FunctionalTest, isinf) {
  const at::Device device("cpu");
  test_isinf<torch::kUInt8, uint8_t>(device);
  test_isinf<torch::kInt8, int8_t>(device);
  test_isinf<torch::kInt16, int16_t>(device);
  test_isinf<torch::kInt32, int32_t>(device);
  test_isinf<torch::kInt64, int64_t>(device);
  test_isinf<torch::kFloat32, float>(device);
  test_isinf<torch::kFloat64, double>(device);
}
3026 
// Same coverage as the CPU isinf test, but on the CUDA device and
// additionally covering kFloat16 (c10::Half).
TEST_F(FunctionalTest, isinf_CUDA) {
  const at::Device device("cuda");
  test_isinf<torch::kUInt8, uint8_t>(device);
  test_isinf<torch::kInt8, int8_t>(device);
  test_isinf<torch::kInt16, int16_t>(device);
  test_isinf<torch::kInt32, int32_t>(device);
  test_isinf<torch::kInt64, int64_t>(device);
  test_isinf<torch::kFloat32, float>(device);
  test_isinf<torch::kFloat64, double>(device);
  test_isinf<torch::kFloat16, c10::Half>(device);
}
3038 
3039 template <c10::ScalarType S, typename T>
test_allclose(const at::Device & device)3040 void test_allclose(const at::Device& device) {
3041   const std::vector<T> values = {
3042       std::numeric_limits<T>::lowest(),
3043       0,
3044       1,
3045       42,
3046       std::numeric_limits<T>::min(),
3047       std::numeric_limits<T>::max()};
3048   for (const auto value : values) {
3049     const auto x =
3050         torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3051     const auto y =
3052         torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3053     ASSERT_TRUE(torch::allclose(x, x));
3054     ASSERT_TRUE(torch::allclose(x, y));
3055     ASSERT_TRUE(torch::allclose(y, x));
3056     ASSERT_FALSE(torch::allclose(1.1 * x + 0.1, 1.0 * x));
3057     ASSERT_TRUE(torch::allclose(0.99 * x + 0.1, 1.0 * x, 1.1, 0.1));
3058   }
3059   if (std::numeric_limits<T>::has_infinity) {
3060     const auto inf = std::numeric_limits<T>::infinity();
3061     const auto x = torch::tensor(
3062         {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3063     const auto y = torch::tensor(
3064         {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3065     ASSERT_TRUE(torch::allclose(x, x));
3066     ASSERT_TRUE(torch::allclose(x, y));
3067     ASSERT_TRUE(torch::allclose(y, x));
3068   }
3069   if (std::numeric_limits<T>::has_quiet_NaN) {
3070     const auto x = torch::tensor(
3071         {std::numeric_limits<T>::quiet_NaN()},
3072         torch::TensorOptions().dtype(S).device(device));
3073     const auto y = torch::tensor(
3074         {std::numeric_limits<T>::quiet_NaN()},
3075         torch::TensorOptions().dtype(S).device(device));
3076     ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3077     ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3078     ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3079   }
3080   if (std::numeric_limits<T>::has_signaling_NaN) {
3081     const auto x = torch::tensor(
3082         {std::numeric_limits<T>::signaling_NaN()},
3083         torch::TensorOptions().dtype(S).device(device));
3084     const auto y = torch::tensor(
3085         {std::numeric_limits<T>::signaling_NaN()},
3086         torch::TensorOptions().dtype(S).device(device));
3087     ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3088     ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3089     ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3090   }
3091 }
3092 
// Exercises torch::allclose on CPU across all integral dtypes plus
// 32/64-bit floats. kFloat16 is exercised only in the CUDA variant.
TEST_F(FunctionalTest, AllClose) {
  const at::Device device("cpu");
  test_allclose<torch::kUInt8, uint8_t>(device);
  test_allclose<torch::kInt8, int8_t>(device);
  test_allclose<torch::kInt16, int16_t>(device);
  test_allclose<torch::kInt32, int32_t>(device);
  test_allclose<torch::kInt64, int64_t>(device);
  test_allclose<torch::kFloat32, float>(device);
  test_allclose<torch::kFloat64, double>(device);
}
3103 
// Same coverage as the CPU AllClose test, but on the CUDA device and
// additionally covering kFloat16 (c10::Half).
TEST_F(FunctionalTest, AllClose_CUDA) {
  const at::Device device("cuda");
  test_allclose<torch::kUInt8, uint8_t>(device);
  test_allclose<torch::kInt8, int8_t>(device);
  test_allclose<torch::kInt16, int16_t>(device);
  test_allclose<torch::kInt32, int32_t>(device);
  test_allclose<torch::kInt64, int64_t>(device);
  test_allclose<torch::kFloat32, float>(device);
  test_allclose<torch::kFloat64, double>(device);
  test_allclose<torch::kFloat16, c10::Half>(device);
}
3115 
TEST_F(FunctionalTest,BCEWithLogitsLoss)3116 TEST_F(FunctionalTest, BCEWithLogitsLoss) {
3117   { // test BCE with logits raises if target and input are different size
3118     {
3119       const auto target = torch::rand(5);
3120       const auto input = torch::rand({5, 1});
3121       ASSERT_THROWS_WITH(
3122           F::binary_cross_entropy_with_logits(input, target),
3123           "must be the same as input size");
3124     }
3125     {
3126       const auto target = torch::rand({5, 1});
3127       const auto input = torch::rand(5);
3128       ASSERT_THROWS_WITH(
3129           F::binary_cross_entropy_with_logits(input, target),
3130           "must be the same as input size");
3131     }
3132   }
3133   { // test BCE with logits gives same result as sigmoid and bce loss
3134     auto sigmoid = Sigmoid();
3135 
3136     auto target = torch::rand({64, 4});
3137     auto output = torch::rand({64, 4}) - 0.5;
3138 
3139     ASSERT_TRUE(torch::allclose(
3140         F::binary_cross_entropy_with_logits(output, target),
3141         F::binary_cross_entropy(sigmoid(output), target)));
3142 
3143     auto weight = torch::rand(4);
3144     ASSERT_TRUE(torch::allclose(
3145         F::binary_cross_entropy_with_logits(
3146             output,
3147             target,
3148             F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3149         F::binary_cross_entropy(
3150             sigmoid(output),
3151             target,
3152             F::BinaryCrossEntropyFuncOptions().weight(weight))));
3153 
3154     target = torch::zeros({4, 1}, torch::kFloat);
3155     output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3156 
3157     ASSERT_TRUE(torch::allclose(
3158         F::binary_cross_entropy_with_logits(output, target),
3159         F::binary_cross_entropy(sigmoid(output), target)));
3160 
3161     ASSERT_TRUE(torch::allclose(
3162         F::binary_cross_entropy_with_logits(
3163             output,
3164             target,
3165             F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(
3166                 torch::kNone)),
3167         F::binary_cross_entropy(
3168             sigmoid(output),
3169             target,
3170             F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))));
3171 
3172     weight = torch::rand({1}, torch::kFloat);
3173     ASSERT_TRUE(torch::allclose(
3174         F::binary_cross_entropy_with_logits(
3175             output,
3176             target,
3177             F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3178         F::binary_cross_entropy(
3179             sigmoid(output),
3180             target,
3181             F::BinaryCrossEntropyFuncOptions().weight(weight))));
3182   }
3183   { // test BCE with logits has correct grad at zero
3184     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3185     const auto target = torch::zeros({3, 1});
3186     F::binary_cross_entropy_with_logits(
3187         output,
3188         target,
3189         F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum))
3190         .backward();
3191     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3192     ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3193   }
3194   { // test BCE with logits broadcasts weights
3195     const auto target = torch::rand({16, 4});
3196     const auto output = torch::rand({16, 4}) - 0.5;
3197 
3198     auto weight = torch::rand(4);
3199     auto out1 = F::binary_cross_entropy_with_logits(
3200         output,
3201         target,
3202         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3203 
3204     weight = weight.expand({16, 4}).contiguous();
3205     auto out2 = F::binary_cross_entropy_with_logits(
3206         output,
3207         target,
3208         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3209 
3210     ASSERT_TRUE(torch::allclose(out1, out2));
3211 
3212     weight = torch::rand({16, 1});
3213     out1 = F::binary_cross_entropy_with_logits(
3214         output,
3215         target,
3216         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3217 
3218     weight = weight.expand({16, 4}).contiguous();
3219     out2 = F::binary_cross_entropy_with_logits(
3220         output,
3221         target,
3222         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3223 
3224     ASSERT_TRUE(torch::allclose(out1, out2));
3225   }
3226   { // test BCE with logits ones in pos weights are the same as none
3227     const auto target = torch::rand({64, 4});
3228     const auto output = torch::rand({64, 4}) - 0.5;
3229     const auto pos_weight = torch::ones({64, 4});
3230 
3231     ASSERT_TRUE(torch::allclose(
3232         F::binary_cross_entropy_with_logits(output, target),
3233         F::binary_cross_entropy_with_logits(
3234             output,
3235             target,
3236             F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(
3237                 pos_weight))));
3238   }
3239   { // test BCE with logits broadcasts pos weights
3240     const auto target = torch::rand({64, 4});
3241     const auto output = torch::rand({64, 4}) - 0.5;
3242     const auto pos_weight = torch::rand(4);
3243     const auto out1 = F::binary_cross_entropy_with_logits(
3244         output,
3245         target,
3246         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3247 
3248     const auto pos_weight1 = pos_weight.expand({1, 4});
3249     const auto out2 = F::binary_cross_entropy_with_logits(
3250         output,
3251         target,
3252         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3253 
3254     const auto pos_weight2 = pos_weight.expand({64, 4});
3255     const auto out3 = F::binary_cross_entropy_with_logits(
3256         output,
3257         target,
3258         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3259 
3260     ASSERT_TRUE(torch::allclose(out1, out2));
3261     ASSERT_TRUE(torch::allclose(out1, out3));
3262   }
3263   { // test BCE with logits with pos weight has correct grad at zero
3264     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3265     const auto target = torch::zeros({3, 1});
3266     const auto pos_weight = torch::ones({3, 1});
3267     F::binary_cross_entropy_with_logits(
3268         output,
3269         target,
3270         F::BinaryCrossEntropyWithLogitsFuncOptions()
3271             .pos_weight(pos_weight)
3272             .reduction(torch::kSum))
3273         .backward();
3274     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3275     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3276     const auto grad = output.grad();
3277     ASSERT_TRUE(torch::allclose(grad, expected_grad));
3278   }
3279   { // test BCE with logits stability
3280     const auto output = torch::tensor({0., -120.});
3281     const auto target = torch::tensor({0., 1.});
3282     const auto pos_weight = torch::tensor({1., 1.});
3283 
3284     const auto out1 = F::binary_cross_entropy_with_logits(output, target);
3285     ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3286 
3287     const auto out2 = F::binary_cross_entropy_with_logits(
3288         output,
3289         target,
3290         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3291     ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3292   }
3293 }
3294