1 #include <gtest/gtest.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5
6 #include <test/cpp/api/support.h>
7
8 namespace F = torch::nn::functional;
9
10 using namespace torch::nn;
11
12 struct FunctionalTest : torch::test::SeedingFixture {};
13
TEST_F(FunctionalTest, Conv1d) {
  // Input {2, 3, 5} convolved with a {2, 3, 3} kernel; both hold arange ramps
  // so the expected output is a fixed table.
  const auto float_grad = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::arange(30, float_grad).reshape({2, 3, 5});
  auto kernel = torch::arange(18, float_grad).reshape({2, 3, 3});

  const auto expected = torch::tensor(
      {{{312., 348., 384.}, {798., 915., 1032.}},
       {{852., 888., 924.}, {2553., 2670., 2787.}}},
      torch::kFloat);

  // An explicit stride of 1 and the default options must both match.
  const auto with_options =
      F::conv1d(input, kernel, F::Conv1dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(with_options, expected));

  const auto with_defaults = F::conv1d(input, kernel);
  ASSERT_TRUE(torch::allclose(with_defaults, expected));
}
31
TEST_F(FunctionalTest, Conv2dEven) {
  // Square case: input {1, 3, 5, 5} with a {2, 3, 3, 3} kernel, both arange
  // ramps so the expected output is a fixed table.
  const auto float_grad = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::arange(75, float_grad).reshape({1, 3, 5, 5});
  auto kernel = torch::arange(54, float_grad).reshape({2, 3, 3, 3});

  const auto expected = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},
        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);

  // An explicit stride of 1 and the default options must both match.
  const auto with_options =
      F::conv2d(input, kernel, F::Conv2dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(with_options, expected));

  const auto with_defaults = F::conv2d(input, kernel);
  ASSERT_TRUE(torch::allclose(with_defaults, expected));
}
53
TEST_F(FunctionalTest, Conv2dUneven) {
  // Non-square case: input {1, 3, 5, 4} with a {2, 3, 3, 2} kernel, both
  // arange ramps so the expected output is a fixed table.
  const auto float_grad = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::arange(60, float_grad).reshape({1, 3, 5, 4});
  auto kernel = torch::arange(36, float_grad).reshape({2, 3, 3, 2});

  const auto expected = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},
        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);

  // An explicit stride of 1 and the default options must both match.
  const auto with_options =
      F::conv2d(input, kernel, F::Conv2dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(with_options, expected));

  const auto with_defaults = F::conv2d(input, kernel);
  ASSERT_TRUE(torch::allclose(with_defaults, expected));
}
73
TEST_F(FunctionalTest, Conv3d) {
  // Input {1, 3, 5, 5, 5} with a {2, 3, 3, 3, 3} kernel, both arange ramps so
  // the expected output is a fixed table.
  const auto float_grad = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::arange(375, float_grad).reshape({1, 3, 5, 5, 5});
  auto kernel = torch::arange(162, float_grad).reshape({2, 3, 3, 3, 3});

  const auto expected = torch::tensor(
      {{{{{700704., 703944., 707184.},
          {716904., 720144., 723384.},
          {733104., 736344., 739584.}},
         {{781704., 784944., 788184.},
          {797904., 801144., 804384.},
          {814104., 817344., 820584.}},
         {{862704., 865944., 869184.},
          {878904., 882144., 885384.},
          {895104., 898344., 901584.}}},
        {{{1724220., 1734021., 1743822.},
          {1773225., 1783026., 1792827.},
          {1822230., 1832031., 1841832.}},
         {{1969245., 1979046., 1988847.},
          {2018250., 2028051., 2037852.},
          {2067255., 2077056., 2086857.}},
         {{2214270., 2224071., 2233872.},
          {2263275., 2273076., 2282877.},
          {2312280., 2322081., 2331882.}}}}},
      torch::kFloat);

  // An explicit stride of 1 and the default options must both match.
  const auto with_options =
      F::conv3d(input, kernel, F::Conv3dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(with_options, expected));

  const auto with_defaults = F::conv3d(input, kernel);
  ASSERT_TRUE(torch::allclose(with_defaults, expected));
}
111
TEST_F(FunctionalTest, MaxPool1d) {
  // Kernel 3 / stride 2 over an all-ones {1, 1, 5} input: all outputs are 1.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::max_pool1d(input, F::MaxPool1dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{1, 1, 2};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
120
TEST_F(FunctionalTest, MaxPool2d) {
  // Kernel 3 / stride 2 over an all-ones 3-D {2, 5, 5} input: all outputs
  // are 1.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::max_pool2d(input, F::MaxPool2dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{2, 2, 2};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
129
TEST_F(FunctionalTest, MaxPool2dBackward) {
  // Backpropagating through max_pool2d (kernel 2) must produce a gradient with
  // the same shape as the random {1, 2, 4, 4} input.
  auto input = torch::rand(
      {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2));
  output.sum().backward();

  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
138
TEST_F(FunctionalTest, MaxPool3d) {
  // Kernel 3 / stride 2 over an all-ones 4-D {2, 5, 5, 5} input: all outputs
  // are 1.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::max_pool3d(input, F::MaxPool3dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{2, 2, 2, 2};

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
147
TEST_F(FunctionalTest, AvgPool1d) {
  // Kernel 3 / stride 2 average over an all-ones {1, 1, 5} input: all outputs
  // are 1.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::avg_pool1d(input, F::AvgPool1dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{1, 1, 2};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
156
TEST_F(FunctionalTest, AvgPool2d) {
  // Kernel 3 / stride 2 average over an all-ones 3-D {2, 5, 5} input: all
  // outputs are 1.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::avg_pool2d(input, F::AvgPool2dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{2, 2, 2};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
165
TEST_F(FunctionalTest, AvgPool3d) {
  // Kernel 3 / stride 2 average over an all-ones 4-D {2, 5, 5, 5} input: all
  // outputs are 1.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::avg_pool3d(input, F::AvgPool3dFuncOptions(3).stride(2));
  const std::vector<int64_t> expected_shape{2, 2, 2, 2};

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
174
TEST_F(FunctionalTest, FractionalMaxPool2d) {
  // Kernel 3, output size 2, applied to all-ones inputs. The plain and the
  // *_with_indices variants must agree on values, and the indices are
  // compared against fixed positions.
  const auto options = F::FractionalMaxPool2dFuncOptions(3).output_size(2);

  // 3-D input {2, 5, 5}.
  {
    const auto input = torch::ones({2, 5, 5});
    const auto pooled = F::fractional_max_pool2d(input, options);

    ASSERT_EQ(pooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2})));
    ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2}));

    const auto result = F::fractional_max_pool2d_with_indices(input, options);
    const auto& values = std::get<0>(result);
    const auto& indices = std::get<1>(result);
    ASSERT_TRUE(torch::equal(pooled, values));
    ASSERT_TRUE(torch::allclose(
        indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
    ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
  }

  // 4-D input {2, 2, 5, 5}.
  {
    const auto input = torch::ones({2, 2, 5, 5});
    const auto pooled = F::fractional_max_pool2d(input, options);

    ASSERT_EQ(pooled.ndimension(), 4);
    ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2, 2})));
    ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

    const auto result = F::fractional_max_pool2d_with_indices(input, options);
    const auto& values = std::get<0>(result);
    const auto& indices = std::get<1>(result);
    ASSERT_TRUE(torch::equal(pooled, values));
    ASSERT_TRUE(torch::allclose(
        indices,
        torch::tensor(
            {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}},
             {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}})));
    ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
  }
}
212
TEST_F(FunctionalTest, FractionalMaxPool3d) {
  // Kernel 3, output size 2, on an all-ones 4-D {2, 5, 5, 5} input. The plain
  // and *_with_indices variants must agree; indices are checked against fixed
  // positions.
  const auto options = F::FractionalMaxPool3dFuncOptions(3).output_size(2);
  const auto input = torch::ones({2, 5, 5, 5});

  const auto pooled = F::fractional_max_pool3d(input, options);
  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  const auto result = F::fractional_max_pool3d_with_indices(input, options);
  const auto& values = std::get<0>(result);
  const auto& indices = std::get<1>(result);
  ASSERT_TRUE(torch::equal(pooled, values));
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
233
TEST_F(FunctionalTest, LPPool1d) {
  const int norm_type = 2;
  const int stride = 2;
  const int kernel_size = 3;

  const auto input = torch::ones({1, 1, 5});
  const auto options =
      F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride);
  const auto pooled = F::lp_pool1d(input, options);

  // Each window of the all-ones input holds kernel_size ones, so every output
  // equals (kernel_size * 1^p)^(1/p).
  const auto ones_to_p =
      torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type);
  const auto expected = (ones_to_p * kernel_size).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({1, 1, 2}));
}
251
TEST_F(FunctionalTest, LPPool2d) {
  const int norm_type = 2;
  const int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  const auto input = torch::ones({1, 1, 2, 5});
  const auto options =
      F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride);
  const auto pooled = F::lp_pool2d(input, options);

  // Each window of the all-ones input holds (kernel height * width) ones.
  const int64_t window_elems = kernel_size[0] * kernel_size[1];
  const auto ones_to_p =
      torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type);
  const auto expected = (ones_to_p * window_elems).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
}
269
TEST_F(FunctionalTest, LPPool3d) {
  const int norm_type = 2;
  const int stride = 2;
  std::vector<int64_t> kernel_size({1, 2, 3});

  const auto input = torch::ones({1, 1, 1, 2, 5});
  const auto options =
      F::LPPool3dFuncOptions(norm_type, kernel_size).stride(stride);
  const auto pooled = F::lp_pool3d(input, options);

  // Each window of the all-ones input holds the product of the kernel extents
  // worth of ones.
  const int64_t window_elems = kernel_size[0] * kernel_size[1] * kernel_size[2];
  const auto ones_to_p =
      torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type);
  const auto expected = (ones_to_p * window_elems).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 5);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
}
287
TEST_F(FunctionalTest, CosineSimilarity) {
  // Cosine similarity along dim 1 between two 2x3 float matrices.
  const auto lhs = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  const auto rhs = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  const auto options = F::CosineSimilarityFuncOptions().dim(1);
  const auto similarity = F::cosine_similarity(lhs, rhs, options);

  // Reference values are given to 4 decimal places, hence rtol = 1e-4.
  const auto reference = torch::tensor({0.8078, 0.8721}, torch::kFloat);
  ASSERT_TRUE(similarity.allclose(reference, 1e-04));
}
296
TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
  // smooth_l1_loss with default options on three elements; also checks that
  // backward produces a gradient shaped like the input.
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::smooth_l1_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
308
TEST_F(FunctionalTest, SmoothL1LossBeta) {
  // smooth_l1_loss with mean reduction and beta = 0.5, passed through the
  // (reduction, beta) overload rather than an options object.
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment)
      F::smooth_l1_loss(
          input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
  auto expected = torch::tensor(1.67, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
323
TEST_F(FunctionalTest, SmoothL1LossBetaOptions) {
  // Same inputs and expectation as the (reduction, beta) overload test, but
  // with beta supplied through SmoothL1LossFuncOptions.
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      F::smooth_l1_loss(
          input,
          target,
          F::SmoothL1LossFuncOptions().reduction(torch::kMean).beta(0.5));
  auto expected = torch::tensor(1.67, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
340
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
  // With reduction kNone the loss is returned per element.
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(bugprone-argument-comment)
      F::smooth_l1_loss(input, target, /*reduction=*/torch::kNone);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
354
TEST_F(FunctionalTest, HuberLossDefaultOptions) {
  // huber_loss with default options; the expected value equals the
  // smooth_l1_loss default-options expectation for the same inputs.
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::huber_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
366
TEST_F(FunctionalTest, HuberLossDelta) {
  // huber_loss with mean reduction and delta = 0.5; the expectation is written
  // as the smooth-L1 (beta = 0.5) value 1.67 scaled by delta.
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor(1.67 * 0.5, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
379
TEST_F(FunctionalTest, HuberLossNoReduction) {
  // With reduction kNone the Huber loss is returned per element.
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kNone);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  output.sum().backward();

  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
392
TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) {
  // soft_margin_loss with default options over four elements with +/-1
  // targets; backward must yield a gradient shaped like the input.
  auto prediction = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
  const auto loss = F::soft_margin_loss(prediction, labels);
  const auto reference = torch::tensor({1.3767317}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
405
TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
  // multilabel_soft_margin_loss with default options on a 2x4 input and a
  // 0/1 target matrix; backward must yield a gradient shaped like the input.
  auto prediction = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  const auto loss = F::multilabel_soft_margin_loss(prediction, labels);
  const auto reference = torch::tensor({0.7608436}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
420
TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
  // With reduction kNone (passed directly in place of an options object) the
  // soft margin loss is returned per element.
  auto prediction = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
  const auto loss = F::soft_margin_loss(prediction, labels, torch::kNone);
  const auto reference = torch::tensor(
      {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
434
TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
  // multilabel_soft_margin_loss with a 4-element weight and reduction kNone:
  // one loss value per row of the 2x4 input.
  auto prediction = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  const auto column_weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);

  auto options = F::MultilabelSoftMarginLossFuncOptions();
  options.reduction(torch::kNone);
  options.weight(column_weight);

  const auto loss = F::multilabel_soft_margin_loss(prediction, labels, options);
  const auto reference = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
453
TEST_F(FunctionalTest, PairwiseDistance) {
  // Row-wise p = 1 distance between two 2x3 float matrices.
  const auto lhs = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  const auto rhs = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  const auto options = F::PairwiseDistanceFuncOptions().p(1);
  const auto distance = F::pairwise_distance(lhs, rhs, options);

  ASSERT_TRUE(distance.allclose(torch::tensor({6, 6}, torch::kFloat)));
}
462
TEST_F(FunctionalTest, PDist) {
  // Default p: a single distance between two 3-element rows.
  {
    const auto rows = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}});
    const auto distances = F::pdist(rows);
    ASSERT_TRUE(distances.allclose(torch::tensor({11.7898})));
  }
  // p = 1.5: three rows give three pairwise distances.
  {
    const auto rows = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}});
    const auto distances = F::pdist(rows, 1.5);
    ASSERT_TRUE(distances.allclose(torch::tensor({4.0, 4.8945, 2.0})));
  }
}
477
TEST_F(FunctionalTest, AdaptiveMaxPool1d) {
  // Adaptive max pooling of an all-ones {1, 1, 5} input to length 3.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::adaptive_max_pool1d(input, F::AdaptiveMaxPool1dFuncOptions(3));
  const std::vector<int64_t> expected_shape{1, 1, 3};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
486
TEST_F(FunctionalTest, AdaptiveMaxPool2d) {
  // Adaptive max pooling of an all-ones 3-D {2, 5, 5} input to 3x3.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::adaptive_max_pool2d(input, F::AdaptiveMaxPool2dFuncOptions(3));
  const std::vector<int64_t> expected_shape{2, 3, 3};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
495
TEST_F(FunctionalTest, AdaptiveMaxPool3d) {
  // Adaptive max pooling of an all-ones 4-D {2, 5, 5, 5} input to 3x3x3.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::adaptive_max_pool3d(input, F::AdaptiveMaxPool3dFuncOptions(3));
  const std::vector<int64_t> expected_shape{2, 3, 3, 3};

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
504
TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
  // Adaptive average pooling of an all-ones {1, 1, 5} input to length 3.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::adaptive_avg_pool1d(input, F::AdaptiveAvgPool1dFuncOptions(3));
  const std::vector<int64_t> expected_shape{1, 1, 3};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
513
TEST_F(FunctionalTest, AdaptiveAvgPool2d) {
  // Adaptive average pooling of an all-ones 3-D {2, 5, 5} input to 3x3.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::adaptive_avg_pool2d(input, F::AdaptiveAvgPool2dFuncOptions(3));
  const std::vector<int64_t> expected_shape{2, 3, 3};

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
522
TEST_F(FunctionalTest, AdaptiveAvgPool3d) {
  // Adaptive average pooling of an all-ones 4-D {2, 5, 5, 5} input to 3x3x3.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::adaptive_avg_pool3d(input, F::AdaptiveAvgPool3dFuncOptions(3));
  const std::vector<int64_t> expected_shape{2, 3, 3, 3};

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
531
TEST_F(FunctionalTest, L1Loss) {
  // Random inputs squashed through sigmoid vs. random 0/1 targets: with the
  // default options the loss is 0-dim and gradients reach `input`.
  auto input = torch::randn({5, 6}, torch::requires_grad());
  const auto target = torch::empty({5, 6}).random_(2);
  const auto loss = F::l1_loss(torch::sigmoid(input), target);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
542
TEST_F(FunctionalTest, MSELoss) {
  // Random inputs squashed through sigmoid vs. random 0/1 targets: with the
  // default options the loss is 0-dim and gradients reach `input`.
  auto input = torch::randn({5, 6}, torch::requires_grad());
  const auto target = torch::empty({5, 6}).random_(2);
  const auto loss = F::mse_loss(torch::sigmoid(input), target);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
553
TEST_F(FunctionalTest, BCELoss) {
  // Sigmoid outputs (in (0, 1), as binary_cross_entropy requires) vs. random
  // 0/1 targets: the default-options loss is 0-dim and gradients reach `input`.
  auto input = torch::randn({5, 6}, torch::requires_grad());
  const auto target = torch::empty({5, 6}).random_(2);
  const auto loss = F::binary_cross_entropy(torch::sigmoid(input), target);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
564
TEST_F(FunctionalTest, KLDivLoss) {
  // Shape / gradient-flow check for F::kl_div with default options: the loss
  // is 0-dim and backward yields a gradient shaped like `input`.
  // NOTE: kl_div documents its first argument as log-probabilities; sigmoid
  // outputs are used here only because this test checks shapes, not values.
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::kl_div(torch::sigmoid(input), target);
  output.sum().backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
576
TEST_F(FunctionalTest, HingeEmbeddingLoss) {
  // hinge_embedding_loss with margin 2 on two 2x3 float tensors.
  const auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat);
  const auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
  const auto options = F::HingeEmbeddingLossFuncOptions().margin(2);
  const auto loss = F::hinge_embedding_loss(input, target, options);

  ASSERT_TRUE(loss.allclose(torch::tensor({10}, torch::kFloat)));
}
586
TEST_F(FunctionalTest, GridSample) {
  // A {1, 1, 3, 3} input holding 0..8, sampled on a 3x3 grid whose first and
  // last points (x = -2 and x = 2) fall outside [-1, 1], so the padding modes
  // affect the result.
  const auto input =
      torch::arange(9, torch::kFloat).view(std::vector<int64_t>({1, 1, 3, 3}));
  const auto grid = torch::tensor(
      {{{{-2., -1.}, {-1., -1.}, {0., -1.}},
        {{-1., 0.}, {0., 0.}, {1., 0.}},
        {{0., 1.}, {1., 1.}, {2., 1.}}}},
      torch::kFloat);

  // Runs grid_sample on (input, grid) with the given mode, padding mode and
  // align_corners flag.
  auto sample = [&](auto mode, auto padding_mode, bool align_corners) {
    return F::grid_sample(
        input,
        grid,
        F::GridSampleFuncOptions()
            .mode(mode)
            .padding_mode(padding_mode)
            .align_corners(align_corners));
  };

  // bilinear, zeros, true
  ASSERT_TRUE(sample(torch::kBilinear, torch::kZeros, true)
                  .allclose(torch::tensor(
                      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}},
                      torch::kFloat)));

  // bilinear, zeros, false
  const auto unaligned_expected = torch::tensor(
      {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat);
  ASSERT_TRUE(sample(torch::kBilinear, torch::kZeros, false)
                  .allclose(unaligned_expected));

  // default options (bilinear, zeros, false) give the same result as above
  ASSERT_TRUE(F::grid_sample(input, grid).allclose(unaligned_expected));

  // nearest, zeros, true
  ASSERT_TRUE(sample(torch::kNearest, torch::kZeros, true)
                  .allclose(torch::tensor(
                      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}},
                      torch::kFloat)));

  // bilinear, border, true
  ASSERT_TRUE(sample(torch::kBilinear, torch::kBorder, true)
                  .allclose(torch::tensor(
                      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}},
                      torch::kFloat)));

  // bilinear, reflection, true
  ASSERT_TRUE(sample(torch::kBilinear, torch::kReflection, true)
                  .allclose(torch::tensor(
                      {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}},
                      torch::kFloat)));
}
656
TEST_F(FunctionalTest, AffineGrid) {
  // 2-D case: two 2x3 affine matrices (values 1..12), output size
  // {2, 3, 2, 2}, with align_corners off and on.
  {
    const auto theta =
        torch::arange(1., 13).view(std::vector<int64_t>({2, 2, 3}));
    const std::vector<int64_t> out_size{2, 3, 2, 2};

    const auto grid = F::affine_grid(theta, out_size, /*align_corners=*/false);
    const auto grid_expected = torch::tensor(
        {{{{1.50, 1.50}, {2.50, 5.50}}, {{3.50, 6.50}, {4.50, 10.50}}},
         {{{1.50, 1.50}, {8.50, 11.50}}, {{9.50, 12.50}, {16.50, 22.50}}}});

    const auto aligned = F::affine_grid(theta, out_size, /*align_corners=*/true);
    const auto aligned_expected = torch::tensor(
        {{{{0.0, -3.0}, {2.0, 5.0}}, {{4.0, 7.0}, {6.0, 15.0}}},
         {{{-6.0, -9.0}, {8.0, 11.0}}, {{10.0, 13.0}, {24.0, 33.0}}}});

    ASSERT_TRUE(grid.allclose(grid_expected));
    ASSERT_TRUE(aligned.allclose(aligned_expected));
  }

  // 3-D case: one 3x4 affine matrix (values 1..12), output size
  // {1, 1, 3, 2, 2}, with align_corners off and on.
  {
    const auto theta =
        torch::arange(1., 13).view(std::vector<int64_t>({1, 3, 4}));
    const std::vector<int64_t> out_size{1, 1, 3, 2, 2};

    const auto grid = F::affine_grid(theta, out_size, /*align_corners=*/false);
    const auto grid_expected = torch::tensor(
        {{{{{0.5000, -2.1667, -4.8333}, {1.5000, 2.8333, 4.1667}},
           {{2.5000, 3.8333, 5.1667}, {3.5000, 8.8333, 14.1667}}},
          {{{2.5000, 2.5000, 2.5000}, {3.5000, 7.5000, 11.5000}},
           {{4.5000, 8.5000, 12.5000}, {5.5000, 13.5000, 21.5000}}},
          {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}},
           {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}});

    const auto aligned = F::affine_grid(theta, out_size, /*align_corners=*/true);
    const auto aligned_expected = torch::tensor(
        {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}},
           {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}},
          {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}},
           {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}},
          {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}},
           {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}});

    // The non-aligned reference is rounded to 4 decimals; use rtol = 1e-2.
    ASSERT_TRUE(grid.allclose(grid_expected, 1e-2));
    ASSERT_TRUE(aligned.allclose(aligned_expected));
  }

  // 2-D argument validation: bad sizes, non-floating theta, wrong theta shape.
  {
    const auto theta = torch::empty({1, 2, 3}, torch::kDouble);
    const std::vector<int64_t> out_size{1, 1, 2, 2};

    ASSERT_THROWS_WITH(
        F::affine_grid(torch::empty({2, 2, 3}), {-1, 1, 2, 2}),
        "Expected non-zero, positive output size. Got [-1, 1, 2, 2]");
    ASSERT_THROWS_WITH(
        F::affine_grid(torch::empty({2, 2, 3}, torch::kInt), out_size),
        "Expected theta to have floating point type, but got int");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta[0], out_size),
        "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
        "[1, 1, 2, 2]. Got [2, 3].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.unsqueeze(0), out_size),
        "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
        "[1, 1, 2, 2]. Got [1, 1, 2, 3].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.repeat({1, 2, 1}), out_size),
        "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
        "[1, 1, 2, 2]. Got [1, 4, 3].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.repeat({1, 1, 2}), out_size),
        "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
        "[1, 1, 2, 2]. Got [1, 2, 6].");
  }

  // 3-D argument validation: wrong theta shape and unsupported size ranks.
  {
    const auto theta = torch::empty({1, 3, 4}, torch::kDouble);
    const std::vector<int64_t> out_size{1, 1, 2, 2, 3};

    ASSERT_THROWS_WITH(
        F::affine_grid(theta[0], out_size),
        "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
        "[1, 1, 2, 2, 3]. Got [3, 4].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.unsqueeze(0), out_size),
        "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
        "[1, 1, 2, 2, 3]. Got [1, 1, 3, 4].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.repeat({1, 2, 1}), out_size),
        "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
        "[1, 1, 2, 2, 3]. Got [1, 6, 4].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta.repeat({1, 1, 2}), out_size),
        "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
        "[1, 1, 2, 2, 3]. Got [1, 3, 8].");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta, {1, 1, 1, 2, 2, 3}),
        "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
        "transforms, respectively. Got size [1, 1, 1, 2, 2, 3]");
    ASSERT_THROWS_WITH(
        F::affine_grid(theta, {1, 1}),
        "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
        "transforms, respectively. Got size [1, 1]");
  }
}
755
TEST_F(FunctionalTest, MultiMarginLoss) {
  // multi_margin_loss with margin 2 and a 3-element weight on a 3x3 input with
  // one target index per row.
  const auto class_weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
  const auto input = torch::tensor(
      {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto target = torch::tensor({2, 1, 0}, torch::kLong);
  const auto options =
      F::MultiMarginLossFuncOptions().margin(2).weight(class_weight);
  const auto loss = F::multi_margin_loss(input, target, options);

  const auto reference = torch::tensor({0.305556}, torch::kFloat);
  ASSERT_TRUE(loss.allclose(reference, 1e-04));
}
768
TEST_F(FunctionalTest, CosineEmbeddingLoss) {
  // cosine_embedding_loss with margin 0.5; the first pair has target 1 and the
  // second pair target -1.
  const auto first = torch::tensor({{2, 3, 4}, {6, 2, 4}});
  const auto second = torch::tensor({{2, 3, 5}, {9, 12, 0}});
  const auto target = torch::tensor({1, -1});
  const auto options = F::CosineEmbeddingLossFuncOptions().margin(0.5);
  const auto loss = F::cosine_embedding_loss(first, second, target, options);

  const auto reference = torch::tensor({0.1004}, torch::kFloat);
  ASSERT_TRUE(loss.allclose(reference, 1e-4));
}
779
TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) {
  // multilabel_margin_loss with default options; the target row lists class
  // indices followed by a -1 entry.
  auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
  const auto loss = F::multilabel_margin_loss(scores, labels);
  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
792
TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) {
  // Same single-row case with reduction kNone (passed directly in place of an
  // options object): one loss value per row.
  auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
  const auto loss = F::multilabel_margin_loss(scores, labels, torch::kNone);
  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(loss.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
805
TEST_F(FunctionalTest, TripletMarginLoss) {
  // triplet_margin_loss with margin 1.0 on single 2-D points; the expected
  // loss is 0.
  const auto anchor = torch::tensor({{3., 3.}}, torch::kFloat);
  const auto positive = torch::tensor({{2., 2.}}, torch::kFloat);
  const auto negative = torch::tensor({{0., 0.}}, torch::kFloat);
  const auto options = F::TripletMarginLossFuncOptions().margin(1.0);
  const auto loss = F::triplet_margin_loss(anchor, positive, negative, options);

  ASSERT_TRUE(loss.allclose(torch::tensor({0.}, torch::kFloat), 1e-04));
}
819
// Parity check: triplet_margin_with_distance_loss with no distance function
// configured must produce the same output as triplet_margin_loss for every
// combination of reduction, margin and swap below. Gradients must also reach
// all three inputs with matching shapes.
TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
  const std::vector<TripletMarginWithDistanceLossOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  const std::vector<float> margins = {0.5, 1.0, 1.5};
  const std::vector<bool> swaps = {true, false};

  for (const auto& reduction : reductions) {
    for (const float margin : margins) {
      for (const bool swap : swaps) {
        const auto anchor = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        const auto positive = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        const auto negative = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

        const auto basic_options = F::TripletMarginLossFuncOptions()
                                       .reduction(reduction)
                                       .margin(margin)
                                       .swap(swap);
        const auto distance_options =
            F::TripletMarginWithDistanceLossFuncOptions()
                .reduction(reduction)
                .margin(margin)
                .swap(swap);

        const auto basic_output = F::triplet_margin_loss(
            anchor, positive, negative, basic_options);
        const auto distance_output = F::triplet_margin_with_distance_loss(
            anchor, positive, negative, distance_options);

        ASSERT_TRUE(distance_output.allclose(basic_output, 1e-6, 1e-6));

        // Summing first makes backward() valid for torch::kNone, whose output
        // is not a scalar.
        distance_output.sum().backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}
868
// nll_loss on three rows of log-probabilities. The result is checked once with
// explicit options (ignore_index -100, mean reduction) and once with defaults;
// both must match the same reference value.
TEST_F(FunctionalTest, NLLLoss) {
  const auto log_probs = torch::tensor(
      {{-0.1315, -3.1315, -2.5315},
       {-3.7038, -0.1038, -2.6038},
       {-2.3422, -1.3422, -0.4422}},
      torch::kFloat);
  const auto classes = torch::tensor({1, 0, 2}, torch::kLong);
  const auto reference = torch::tensor(2.4258, torch::kFloat);

  const auto explicit_options =
      F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean);
  ASSERT_TRUE(F::nll_loss(log_probs, classes, explicit_options)
                  .allclose(reference, 1e-04));
  ASSERT_TRUE(F::nll_loss(log_probs, classes).allclose(reference, 1e-04));
}
884
// cross_entropy checked in three scenarios:
// 1. plain class-index targets, with explicit options and with defaults;
// 2. label smoothing with class-index targets;
// 3. label smoothing with per-class probability targets.
TEST_F(FunctionalTest, CrossEntropy) {
  const auto class_indices = torch::tensor({0, 1}, torch::kLong);

  {
    const auto logits = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat);
    const auto reference = torch::tensor(0.6931, torch::kFloat);
    const auto loss = F::cross_entropy(
        logits,
        class_indices,
        F::CrossEntropyFuncOptions().ignore_index(-100).reduction(
            torch::kMean));

    ASSERT_TRUE(loss.allclose(reference, 1e-04));
    ASSERT_TRUE(
        F::cross_entropy(logits, class_indices).allclose(reference, 1e-04));
  }

  const auto smoothed_logits =
      torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);

  {
    const auto loss = F::cross_entropy(
        smoothed_logits,
        class_indices,
        F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(
            torch::kMean));
    ASSERT_TRUE(loss.allclose(torch::tensor(0.3326, torch::kFloat), 1e-04));
  }

  {
    const auto probabilities =
        torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
    const auto loss = F::cross_entropy(
        smoothed_logits,
        probabilities,
        F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(
            torch::kMean));
    ASSERT_TRUE(loss.allclose(torch::tensor(0.5701, torch::kFloat), 1e-04));
  }
}
917
// max_unpool1d with kernel size 3, checked on:
// 1. a 3-D input;
// 2. the same data as a 2-D input;
// 3. a 3-D input with an explicit output_size;
// 4. a 3-D input with stride 2 and padding 1.
TEST_F(FunctionalTest, MaxUnpool1d) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);

  {
    const auto pooled = torch::tensor({{{2, 4, 5}}}, float_with_grad);
    const auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    const auto unpooled =
        F::max_unpool1d(pooled, positions, F::MaxUnpool1dFuncOptions(3));
    const auto reference =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 1, 9}));
  }

  {
    const auto pooled = torch::tensor({{2, 4, 5}}, float_with_grad);
    const auto positions = torch::tensor({{1, 3, 4}}, torch::kLong);
    const auto unpooled =
        F::max_unpool1d(pooled, positions, F::MaxUnpool1dFuncOptions(3));
    const auto reference =
        torch::tensor({{0, 2, 0, 4, 5, 0, 0, 0, 0}}, torch::kFloat);

    ASSERT_EQ(unpooled.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 9}));
  }

  {
    const auto pooled = torch::tensor({{{2, 4, 5}}}, float_with_grad);
    const auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    const auto unpooled = F::max_unpool1d(
        pooled,
        positions,
        F::MaxUnpool1dFuncOptions(3).output_size(
            std::vector<int64_t>({1, 1, 9})));
    const auto reference =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 1, 9}));
  }

  {
    const auto pooled = torch::tensor({{{2, 4, 5}}}, float_with_grad);
    const auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    const auto unpooled = F::max_unpool1d(
        pooled, positions, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1));
    const auto reference = torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat);

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 1, 5}));
  }
}
964
// max_unpool2d with kernel 3, stride 2, padding 1, applied first to a 4-D
// input and then to the same data without the size-1 second dimension.
TEST_F(FunctionalTest, MaxUnpool2d) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);

  {
    const auto positions = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
        torch::kLong);
    const auto pooled = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
        float_with_grad);
    const auto unpooled = F::max_unpool2d(
        pooled, positions, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));
    const auto reference = torch::tensor(
        {{{{0, 0, 0, 0, 0},
           {0, 6, 0, 8, 9},
           {0, 0, 0, 0, 0},
           {0, 16, 0, 18, 19},
           {0, 21, 0, 23, 24}}},
         {{{0, 0, 0, 0, 0},
           {0, 31, 0, 33, 34},
           {0, 0, 0, 0, 0},
           {0, 41, 0, 43, 44},
           {0, 46, 0, 48, 49}}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 4);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
  }

  {
    const auto positions = torch::tensor(
        {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
         {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
        torch::kLong);
    const auto pooled = torch::tensor(
        {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
         {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}},
        float_with_grad);
    const auto unpooled = F::max_unpool2d(
        pooled, positions, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));
    const auto reference = torch::tensor(
        {{{0, 0, 0, 0, 0},
          {0, 6, 0, 8, 9},
          {0, 0, 0, 0, 0},
          {0, 16, 0, 18, 19},
          {0, 21, 0, 23, 24}},
         {{0, 0, 0, 0, 0},
          {0, 31, 0, 33, 34},
          {0, 0, 0, 0, 0},
          {0, 41, 0, 43, 44},
          {0, 46, 0, 48, 49}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 3);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({2, 5, 5}));
  }
}
1022
// max_unpool3d with kernel 3. A single pooled value with index 26 is scattered
// back into a 3x3x3 volume, where it lands in the last cell. Checked with a 5-D
// input and again with a 4-D input.
TEST_F(FunctionalTest, MaxUnpool3d) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);

  {
    const auto positions = torch::tensor({{{{{26}}}}}, torch::kLong);
    const auto pooled = torch::tensor({{{{{26}}}}}, float_with_grad);
    const auto unpooled =
        F::max_unpool3d(pooled, positions, F::MaxUnpool3dFuncOptions(3));
    const auto reference = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 5);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
  }

  {
    const auto positions = torch::tensor({{{{26}}}}, torch::kLong);
    const auto pooled = torch::tensor({{{{26}}}}, float_with_grad);
    const auto unpooled =
        F::max_unpool3d(pooled, positions, F::MaxUnpool3dFuncOptions(3));
    const auto reference = torch::tensor(
        {{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
          {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
          {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 4);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
  }
}
1054
// F::elu over several alpha values, out-of-place and in-place, in float and
// bfloat16. The float result is compared with the closed form
// max(0, x) + min(0, alpha * (exp(x) - 1)). The bfloat16 result is compared
// with the float result at a loose 1e-2 tolerance.
TEST_F(FunctionalTest, ELU) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const bool inplace : {false, true}) {
    for (const double alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_(cube);
      auto input_bf16 =
          torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16);
      input_bf16.resize_(cube);

      // Built before F::elu runs, because the in-place variant overwrites
      // `input`.
      const auto zeros = torch::zeros_like(input);
      const auto expected = torch::max(zeros, input) +
          torch::min(zeros, alpha * (torch::exp(input) - 1.0));

      const auto options = F::ELUFuncOptions().alpha(alpha).inplace(inplace);
      const auto output = F::elu(input, options);
      const auto output_bf16 = F::elu(input_bf16, options);

      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), cube);
      ASSERT_TRUE(torch::allclose(output, expected));
      ASSERT_TRUE(
          torch::allclose(output_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      if (inplace) {
        ASSERT_TRUE(torch::allclose(input, expected));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      }
    }
  }
  ASSERT_TRUE(F::elu(torch::tensor(1.)).defined());
}
1083
// F::selu against the closed form scale * (max(0, x) + min(0, alpha * (e^x - 1))),
// in float and bfloat16, in-place and out-of-place. Also checks that the
// default options agree with an explicit `false` inplace flag.
TEST_F(FunctionalTest, SELU) {
  constexpr double kScale = 1.0507009873554804934193349852946;
  constexpr double kAlpha = 1.6732632423543772848170429916717;

  for (const bool inplace : {false, true}) {
    auto input = torch::randn({5, 5});
    auto input_bf16 = input.clone().to(torch::kBFloat16);

    // Reference computed before F::selu may overwrite `input` in place.
    const auto zeros = torch::zeros_like(input);
    const auto expected = kScale *
        (torch::max(zeros, input) +
         torch::min(zeros, kAlpha * (torch::exp(input) - 1)));

    const auto output = F::selu(input, inplace);
    const auto output_bf16 = F::selu(input_bf16, inplace);

    ASSERT_TRUE(output.allclose(expected));
    ASSERT_TRUE(output_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
    if (inplace) {
      ASSERT_TRUE(input.allclose(expected));
      ASSERT_TRUE(input_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
    }
  }

  {
    const auto grid = torch::arange(0, 9, torch::kDouble).view({3, 3});
    const auto with_defaults = F::selu(grid);
    const auto explicit_out_of_place = F::selu(grid, false);
    ASSERT_TRUE(with_defaults.allclose(explicit_out_of_place));
  }

  ASSERT_TRUE(F::selu(torch::tensor(1.)).defined());
}
1114
// F::glu splits the input in half along `dim` and gates the first half with
// sigmoid of the second. Checked with an explicit dim of 1 and with the
// default dim on a 2-D input.
TEST_F(FunctionalTest, GLU) {
  const int64_t dim = 1;
  const auto input = torch::randn({4, 2}, torch::requires_grad());
  const auto output = F::glu(input, dim);

  const auto half = input.size(dim) / 2;
  const auto gated = input.narrow(dim, 0, half);
  const auto gate = input.narrow(dim, half, half);
  const auto expected = gated * torch::sigmoid(gate);

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(F::glu(input).allclose(expected));
}
1127
// Exact (approximate = "none") GELU compared with x * 0.5 * (1 + erf(x / sqrt 2)).
TEST_F(FunctionalTest, GELU) {
  const auto input = torch::linspace(-3.0, 3.0, 100);
  const auto erf_term = torch::erf(input / std::sqrt(2.0));
  const auto expected = input * 0.5 * (1.0 + erf_term);

  const auto output =
      F::gelu(input, F::GELUFuncOptions().approximate("none"));

  ASSERT_TRUE(torch::allclose(output, expected, 1.4e-06, 1e-05));
}
1134
// Tanh-approximated GELU compared with
// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
TEST_F(FunctionalTest, TanhGELU) {
  const auto input = torch::linspace(-3.0, 3.0, 100);
  const auto cubic = input + 0.044715 * input.pow(3.0);
  const auto tanh_arg = std::sqrt(2 / M_PI) * cubic;
  const auto expected = 0.5 * input * (1.0 + tanh_arg.tanh());

  const auto output =
      F::gelu(input, F::GELUFuncOptions().approximate("tanh"));

  ASSERT_TRUE(torch::allclose(output, expected, 1.4e-06, 1e-05));
}
1142
// F::hardshrink over a range of lambda values, including negative ones.
// Each run checks that backward() succeeds through the op and that the output
// equals (|x| > lambda) * x.
TEST_F(FunctionalTest, Hardshrink) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const double lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_(cube).set_requires_grad(true);

    const auto output =
        F::hardshrink(input, F::HardshrinkFuncOptions().lambda(lambda));
    const torch::Tensor total = output.sum();
    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), cube);
    const auto expected = (input.abs() > lambda) * input;
    ASSERT_TRUE(torch::allclose(output, expected));
  }
}
1160
// F::one_hot on labels 0..2. Covers an inferred class count, an explicit class
// count of 5, and a 2-D label tensor.
TEST_F(FunctionalTest, OneHot) {
  {
    const auto labels = torch::arange(0, 5, torch::kLong) % 3;
    const auto encoded = F::one_hot(labels);
    const auto reference = torch::tensor(
        {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({5, 3}));
  }

  {
    const auto labels = torch::arange(0, 5, torch::kLong) % 3;
    const auto encoded = F::one_hot(labels, 5);
    const auto reference = torch::tensor(
        {{1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0},
         {0, 0, 1, 0, 0},
         {1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({5, 5}));
  }

  {
    const auto labels =
        torch::arange(0, 6, torch::kLong).view(std::vector<int64_t>({3, 2})) %
        3;
    const auto encoded = F::one_hot(labels);
    const auto reference = torch::tensor(
        {{{1, 0, 0}, {0, 1, 0}},
         {{0, 0, 1}, {1, 0, 0}},
         {{0, 1, 0}, {0, 0, 1}}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), std::vector<int64_t>({3, 2, 3}));
  }
}
1203
// F::hardtanh over a grid of [min_val, max_val] bounds, in-place and
// out-of-place, compared with a piecewise reference built from masks.
TEST_F(FunctionalTest, Hardtanh) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const double min_val : {-4.2, -1.0, -0.42, 0.0}) {
    for (const double max_val : {0.0, 0.42, 1.0, 4.2}) {
      for (const bool inplace : {false, true}) {
        auto input = torch::linspace(-10.0, 10.0, size * size * size);
        input.resize_(cube);

        // Reference computed before F::hardtanh may overwrite `input`.
        const auto below = input < min_val;
        const auto inside = (input >= min_val) * (input <= max_val);
        const auto above = input > max_val;
        const auto expected =
            below * min_val + inside * input + above * max_val;

        const auto options = F::HardtanhFuncOptions()
                                 .min_val(min_val)
                                 .max_val(max_val)
                                 .inplace(inplace);
        const auto output = F::hardtanh(input, options);

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), cube);
        ASSERT_TRUE(torch::allclose(output, expected));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
  ASSERT_TRUE(F::hardtanh(torch::tensor(1.)).defined());
}
1229
// F::leaky_relu over several negative slopes, in float and bfloat16,
// in-place and out-of-place.
TEST_F(FunctionalTest, LeakyReLU) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const double negative_slope : {0.0, 0.42, 1.0}) {
    for (const bool inplace : {false, true}) {
      for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
        auto input =
            torch::linspace(-10.0, 10.0, size * size * size).to(dtype);
        input.resize_(cube);

        // Reference computed before F::leaky_relu may overwrite `input`.
        const auto expected =
            (input < 0) * input * negative_slope + (input >= 0) * input;

        const auto options = F::LeakyReLUFuncOptions()
                                 .negative_slope(negative_slope)
                                 .inplace(inplace);
        const auto output = F::leaky_relu(input, options);

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), cube);
        ASSERT_TRUE(torch::allclose(output, expected));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
  ASSERT_TRUE(F::leaky_relu(torch::tensor(1.)).defined());
}
1255
// F::logsigmoid on a 3x3x3 grid of values spanning [-10, 10], compared with
// the direct formula log(1 / (1 + exp(-x))).
// (The unused `LogSigmoid model;` module instance has been dropped; this test
// only exercises the functional API.)
TEST_F(FunctionalTest, LogSigmoid) {
  const auto size = 3;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size});
  auto y = F::logsigmoid(x);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = torch::log(
      torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
}
1269
// Checks for F::gumbel_softmax:
// - samples are not negative;
// - the output shape matches the logits;
// - each slice along `dim` sums to one, i.e. exactly one pick per draw when
//   hard=true.
// The final case draws many straight-through (hard) samples from logits that
// require grad, and checks the pick frequencies against softmax(logits).
//
// Note: the non-negativity checks use item<int>(), which truncates toward
// zero. They therefore only rule out minima <= -1.
TEST_F(FunctionalTest, GumbelSoftmax) {
  // Test 1: No-options
  {
    auto logits = torch::randn({5});
    int expected_count = 1;
    auto y_draw = F::gumbel_softmax(logits);

    // All values positive
    ASSERT_GE(y_draw.min().item<int>(), 0);
    // Shape unchanged
    ASSERT_TRUE(y_draw.sizes() == logits.sizes());
    // Sums to one along the sampled dimension
    ASSERT_TRUE(torch::allclose(
        y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
  }

  // Test 2: 1D shape, 0 and -1 dim
  for (const auto dim : {0, -1}) {
    auto logits = torch::randn({5});
    int expected_count = 1;
    auto y_draw = F::gumbel_softmax(
        logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim));

    // All values positive
    ASSERT_GE(y_draw.min().item<int>(), 0);
    // Shape unchanged
    ASSERT_TRUE(y_draw.sizes() == logits.sizes());
    // One choice per draw
    ASSERT_TRUE(torch::allclose(
        y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
  }

  { // Test 3: 2D shape, 1 dim
    auto logits = torch::randn({5, 4});
    int expected_count = 5;
    auto y_draw = F::gumbel_softmax(
        logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1));

    // All values positive
    ASSERT_GE(y_draw.min().item<int>(), 0);
    // Shape unchanged
    ASSERT_TRUE(y_draw.sizes() == logits.sizes());
    // One choice per draw
    ASSERT_TRUE(torch::allclose(
        y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
  }

  // Test 4: 3D shape, 1 and -1 dim
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  int dims[] = {1, -1};
  // Number of slices along each sampled dim: 5*3 for dim 1, 5*4 for dim -1.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers)
  int expected[] = {5 * 3, 5 * 4};
  for (const auto i : c10::irange(2)) {
    auto logits = torch::randn({5, 4, 3});
    int expected_count = expected[i];
    auto y_draw = F::gumbel_softmax(
        logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i]));

    // All values positive
    ASSERT_GE(y_draw.min().item<int>(), 0);
    // Shape unchanged
    ASSERT_TRUE(y_draw.sizes() == logits.sizes());
    // One choice per draw
    ASSERT_TRUE(torch::allclose(
        y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
  }

  { // Test 5: Straight through
    int num_draws = 100;
    auto logits = torch::tensor({{0.2, 0.8, 0.1}});
    logits = logits.reshape({1, 3});
    // requires_grad_() (trailing underscore) is the in-place setter.
    // requires_grad() without it is only a getter and would leave the logits
    // untracked by autograd.
    logits.requires_grad_();
    auto probs = logits.softmax(-1);

    auto counts = torch::zeros_like(logits);
    torch::Tensor y_draw;
    for (const auto i : c10::irange(num_draws)) {
      (void)i; // Suppress unused variable warning
      y_draw =
          F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true));
      counts += y_draw;
    }

    // All values positive
    ASSERT_GE(y_draw.min().item<int>(), 0);
    // Each experiment should result in 1 draw
    ASSERT_EQ(counts.sum().item<int>(), num_draws);

    // Check results are asymptotically as expected
    auto expected = probs * num_draws;
    // ~z is approximately N(0,1) for unbiased count
    auto z = (counts - expected) / (expected * (1 - probs)).sqrt();
    // A (lazy) approximate 99% two-sided test:
    // occurs with prob alpha~>=0.01 if unbiased
    ASSERT_LT(z.abs().max().item<float>(), 2.58);
  }
}
1367
// F::softmax along dim 1 compared row by row with exp(x) / sum(exp(x)). The dim
// is passed as a bare integer to exercise the implicit options conversion.
TEST_F(FunctionalTest, Softmax) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  const auto output = F::softmax(input, 1);
  const auto row_sums = torch::sum(torch::exp(input), 1);

  for (const auto row : c10::irange(input.size(0))) {
    const auto expected_row = torch::exp(input[row]) / row_sums[row];
    ASSERT_TRUE(torch::allclose(output[row], expected_row));
  }
}
1379
// F::softmin along dim 1 compared row by row with exp(-x) / sum(exp(-x)). The
// dim is passed as a bare integer to exercise the implicit options conversion.
TEST_F(FunctionalTest, Softmin) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  const auto output = F::softmin(input, 1);
  const auto row_sums = torch::sum(torch::exp(-input), 1);

  for (const auto row : c10::irange(input.size(0))) {
    const auto expected_row = torch::exp(-input[row]) / row_sums[row];
    ASSERT_TRUE(torch::allclose(output[row], expected_row));
  }
}
1391
// F::log_softmax along dim 1 compared row by row with log(exp(x) / sum(exp(x))).
// The dim is passed as a bare integer to exercise the implicit options
// conversion.
TEST_F(FunctionalTest, LogSoftmax) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  const auto output = F::log_softmax(input, 1);
  const auto row_sums = torch::sum(torch::exp(input), 1);

  for (const auto row : c10::irange(input.size(0))) {
    const auto expected_row = torch::log(torch::exp(input[row]) / row_sums[row]);
    ASSERT_TRUE(torch::allclose(output[row], expected_row));
  }
}
1403
// F::prelu with one slope per column (24 slopes broadcast over a 42x24 input).
// Inputs and slopes are rand() * 200 - 100. The output is compared with
// (x < 0) * w * x + (x >= 0) * x.
TEST_F(FunctionalTest, PReLU) {
  const auto input = torch::rand({42, 24}) * 200 - 100;
  const auto slopes = torch::rand(24) * 200 - 100;

  const auto output = F::prelu(input, slopes);
  ASSERT_EQ(output.sizes(), std::vector<int64_t>({42, 24}));

  const auto negative_part = (input < 0) * slopes * input;
  const auto positive_part = (input >= 0) * input;
  ASSERT_TRUE(torch::allclose(output, negative_part + positive_part));
}
1412
// F::layer_norm must agree with torch::layer_norm given the same shape and eps
// and no affine weight or bias.
TEST_F(FunctionalTest, LayerNorm) {
  const auto input = torch::randn({2, 2});
  const auto options = F::LayerNormFuncOptions({2, 2}).eps(2e-5);

  const auto output = F::layer_norm(input, options);
  const auto reference =
      torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5);

  ASSERT_TRUE(torch::allclose(output, reference));
}
1420
// F::group_norm with 2 groups must agree with torch::group_norm given the same
// eps and no affine weight or bias.
TEST_F(FunctionalTest, GroupNorm) {
  const auto input = torch::randn({2, 2});
  const auto options = F::GroupNormFuncOptions(2).eps(2e-5);

  const auto output = F::group_norm(input, options);
  const auto reference =
      torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5);

  ASSERT_TRUE(torch::allclose(output, reference));
}
1428
// F::local_response_norm with size 2 on the values 100..117 arranged as 3x3x2,
// compared with precomputed reference values.
TEST_F(FunctionalTest, LocalResponseNorm) {
  const auto input = torch::arange(100, 118).resize_({3, 3, 2});
  const auto output =
      F::local_response_norm(input, F::LocalResponseNormFuncOptions(2));
  const auto reference = torch::tensor(
      {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}},
       {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}},
       {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}},
      torch::kFloat);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 2}));
  ASSERT_TRUE(torch::allclose(output, reference, 1e-4, 1e-7));
}
1441
// F::linear on a 3x3x2 input with a 3x2 weight, with and without a bias,
// compared with precomputed results.
TEST_F(FunctionalTest, Linear) {
  const auto input = torch::arange(100., 118).resize_({3, 3, 2});
  const auto weight = torch::arange(200., 206).resize_({3, 2});

  {
    const auto bias = torch::arange(300., 303);
    const auto output = F::linear(input, weight, bias);
    const auto reference = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 3}));
    ASSERT_TRUE(torch::allclose(output, reference));
  }

  {
    const auto output = F::linear(input, weight);
    const auto reference = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 3}));
    ASSERT_TRUE(torch::allclose(output, reference));
  }
}
1471
// F::embedding with default options must agree with torch::embedding called
// with padding_idx -1 and both boolean flags false.
TEST_F(FunctionalTest, Embedding) {
  const auto indices =
      torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  auto table = torch::empty({10, 3});
  torch::nn::init::normal_(table);

  const auto output = F::embedding(indices, table);
  const auto reference =
      torch::embedding(table, indices.contiguous(), -1, false, false);

  ASSERT_TRUE(torch::allclose(output, reference));
}
1480
// F::embedding_bag compared with the raw torch::embedding_bag op:
// 1. a flat index list with explicit offsets, mode kSum and padding_idx 4
//    (matched against op mode 0);
// 2. a 2-D index tensor with default options, which the test expects to match
//    op mode 1 with one bag per row (presumably mean mode — confirm against
//    EmbeddingBagFuncOptions defaults).
TEST_F(FunctionalTest, EmbeddingBag) {
  const auto flat_indices =
      torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
  auto bag_offsets = torch::tensor({0, 4}, torch::kLong);
  auto weight = torch::empty({10, 3});
  torch::nn::init::normal_(weight);
  auto y = F::embedding_bag(
      flat_indices,
      weight,
      F::EmbeddingBagFuncOptions()
          .mode(torch::kSum)
          .offsets(bag_offsets)
          .padding_idx(4));
  auto y_exp = std::get<0>(torch::embedding_bag(
      weight, flat_indices, bag_offsets, false, 0, false, torch::Tensor(),
      false, 4));
  ASSERT_TRUE(torch::allclose(y, y_exp));

  // no options test: a 2-D input is treated as one bag per row, so the
  // reference offsets step by the row length.
  const auto batched_indices =
      torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  // The offsets index into `batched_indices`, so they are placed on its device.
  auto row_offsets = torch::arange(
      0,
      batched_indices.numel(),
      batched_indices.size(1),
      torch::TensorOptions()
          .dtype(torch::kLong)
          .device(batched_indices.device()));
  y = F::embedding_bag(batched_indices, weight);
  y_exp = std::get<0>(torch::embedding_bag(
      weight, batched_indices.reshape(-1), row_offsets, false, 1, false,
      torch::Tensor()));
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1509
// F::bilinear on small integer tensors, with and without bias, compared with
// hand-computed values. Mixed input dtypes must be rejected with an error that
// names each dtype.
TEST_F(FunctionalTest, Bilinear) {
  const auto lhs = torch::tensor({{1, 2, 3}, {7, 6, 5}});
  const auto rhs = torch::tensor({{7, 4}, {8, 9}});
  const auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}});
  const auto bias = torch::tensor({1});

  const auto with_bias = F::bilinear(lhs, rhs, weight, bias);
  ASSERT_EQ(with_bias.ndimension(), 2);
  ASSERT_EQ(with_bias.sizes(), torch::IntArrayRef({2, 1}));
  const auto with_bias_reference =
      torch::tensor({{449}, {1702}}).reshape({2, 1});
  ASSERT_TRUE(torch::allclose(with_bias, with_bias_reference, 1e-4, 1e-7));

  const auto without_bias = F::bilinear(lhs, rhs, weight);
  ASSERT_EQ(without_bias.ndimension(), 2);
  ASSERT_EQ(without_bias.sizes(), torch::IntArrayRef({2, 1}));
  const auto without_bias_reference =
      torch::tensor({{448, 1701}}).reshape({2, 1});
  ASSERT_TRUE(
      torch::allclose(without_bias, without_bias_reference, 1e-4, 1e-7));

  const auto lhs_double = lhs.to(torch::kFloat64);
  const auto rhs_int = rhs.to(torch::kInt32);
  const auto weight_int = weight.to(torch::kInt32);
  ASSERT_THROWS_WITH(
      F::bilinear(lhs_double, rhs_int, weight_int),
      "All tensors must have the same dtype, got input1: double, input2: int, weight: int");
}
1535
// F::normalize with p = 1 along the last dim. The two rows [0..4] and [5..9]
// are divided by their L1 norms (10 and 35).
TEST_F(FunctionalTest, Normalize) {
  const auto expected = torch::tensor(
      {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000},
        {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}},
      torch::requires_grad().dtype(torch::kFloat));

  { // Value check, plus a gradient for every input element.
    const auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}},
        torch::dtype(torch::kFloat).requires_grad(true));
    const auto normalized =
        F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));

    // Reduce to a scalar so backward() can be called.
    const torch::Tensor total = normalized.sum();
    total.backward();

    ASSERT_EQ(total.ndimension(), 0);
    ASSERT_EQ(input.grad().numel(), 10);
    ASSERT_TRUE(torch::allclose(normalized, expected));
  }

  { // Optional arguments: an explicit `out` tensor, and default options.
    const auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat));
    const auto destination =
        torch::randn({1, 2, 5}, torch::dtype(torch::kFloat));
    F::normalize(
        input, F::NormalizeFuncOptions().p(1).dim(-1).out(destination));
    // Result intentionally discarded: this only checks the default call.
    F::normalize(input);

    ASSERT_TRUE(torch::allclose(destination, expected));
  }

  { // 0-dimensional (scalar) input.
    const auto scalar = torch::randn({}, torch::requires_grad());
    const torch::Tensor normalized =
        F::normalize(scalar, F::NormalizeFuncOptions().p(1).dim(-1));
    normalized.backward();

    ASSERT_EQ(scalar.grad().numel(), 1);
  }
}
1577
// F::relu called through ReLUFuncOptions and through the bare bool `inplace`
// shorthand, in both in-place and out-of-place modes.
TEST_F(FunctionalTest, ReLU) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const bool inplace : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_(cube);
    // Reference computed before F::relu may overwrite `input`.
    const auto expected = (input < 0) * 0 + (input >= 0) * input;

    auto output = F::relu(input, F::ReLUFuncOptions().inplace(inplace));
    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), cube);
    ASSERT_TRUE(torch::allclose(output, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }

    // Same check via the implicit bool -> ReLUFuncOptions conversion.
    output = F::relu(input, inplace);
    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), cube);
    ASSERT_TRUE(torch::allclose(output, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }
  }
  ASSERT_TRUE(F::relu(torch::tensor(1.)).defined());
}
1605
// F::relu with no options argument.
TEST_F(FunctionalTest, ReLUDefaultOptions) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_(cube);
  const auto expected = (input < 0) * 0 + (input >= 0) * input;

  const auto output = F::relu(input);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), cube);
  ASSERT_TRUE(torch::allclose(output, expected));
}
1617
// F::relu6 (clamp to [0, 6]) called through ReLU6FuncOptions and through the
// bare bool `inplace` shorthand, in both in-place and out-of-place modes.
TEST_F(FunctionalTest, ReLU6) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  for (const bool inplace : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_(cube);
    // Reference computed before F::relu6 may overwrite `input`.
    const auto expected =
        (input < 0) * 0 + ((input >= 0) * (input <= 6)) * input + (input > 6) * 6;

    auto output = F::relu6(input, F::ReLU6FuncOptions().inplace(inplace));
    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), cube);
    ASSERT_TRUE(torch::allclose(output, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }

    // Same check via the implicit bool -> ReLU6FuncOptions conversion.
    output = F::relu6(input, inplace);
    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), cube);
    ASSERT_TRUE(torch::allclose(output, expected));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(input, expected));
    }
  }
  ASSERT_TRUE(F::relu6(torch::tensor(1.)).defined());
}
1645
// F::relu6 with no options argument.
TEST_F(FunctionalTest, ReLU6DefaultOptions) {
  const auto size = 3;
  const std::vector<int64_t> cube = {size, size, size};
  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_(cube);
  const auto expected =
      (input < 0) * 0 + ((input >= 0) * (input <= 6)) * input + (input > 6) * 6;

  const auto output = F::relu6(input);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), cube);
  ASSERT_TRUE(torch::allclose(output, expected));
}
1657
// F::rrelu over a grid of (lower, upper, inplace, dtype). Non-negative inputs
// must pass through unchanged. Negative inputs are scaled by a random slope in
// [lower, upper], so they land in [upper * x, lower * x].
TEST_F(FunctionalTest, RReLU) {
  const auto size = 3;
  const std::vector<int64_t> shape{size, size, size};
  for (const double lower : {0.01, 0.1, 0.2}) {
    for (const double upper : {0.3, 0.4, 0.5}) {
      for (const bool inplace : {false, true}) {
        for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
          auto input =
              torch::linspace(-10.0, 10.0, size * size * size).to(dtype);
          input.resize_(shape);
          // Snapshot taken before the call, since inplace=true overwrites `input`.
          const auto original = input.clone();
          const auto output = F::rrelu(
              input,
              F::RReLUFuncOptions().lower(lower).upper(upper).inplace(inplace));

          const auto positive_ok = (original >= 0) * (original == output);
          const auto negative_ok = (original < 0) *
              (output >= original * upper) * (output <= lower * original);
          const auto all_ok = (positive_ok + negative_ok) * 1.0;

          ASSERT_EQ(output.ndimension(), 3);
          ASSERT_EQ(output.sizes(), shape);
          ASSERT_TRUE(torch::allclose(all_ok, torch::ones_like(all_ok)));
          if (inplace) {
            ASSERT_TRUE(torch::allclose(input, output));
          }
        }
      }
    }
  }
  ASSERT_TRUE(F::rrelu(torch::tensor(1.)).defined());
}
1687
// F::rrelu with its default slope range [1/8, 1/3], for float and bfloat16.
TEST_F(FunctionalTest, RReLUDefaultOptions) {
  const auto size = 3;
  const std::vector<int64_t> shape{size, size, size};
  const double lower = 1.0 / 8.0;
  const double upper = 1.0 / 3.0;
  for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size).to(dtype);
    input.resize_(shape);
    const auto original = input.clone();
    const auto output = F::rrelu(input);

    const auto positive_ok = (original >= 0) * (original == output);
    const auto negative_ok = (original < 0) * (output >= original * upper) *
        (output <= lower * original);
    const auto all_ok = (positive_ok + negative_ok) * 1.0;

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), shape);
    ASSERT_TRUE(torch::allclose(all_ok, torch::ones_like(all_ok)));
  }
}
1706
// F::celu computes max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). It is
// checked in float against that formula and in bfloat16 against the float
// result (loose tolerance), both out of place and in place.
TEST_F(FunctionalTest, CELU) {
  const auto size = 3;
  const std::vector<int64_t> shape{size, size, size};
  for (const bool inplace : {false, true}) {
    for (const double alpha : {0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_(shape);
      auto input_bf16 = input.clone().to(torch::kBFloat16);

      // The reference must be computed before the calls below can overwrite `input`.
      const auto zeros = torch::zeros_like(input);
      const auto expected = torch::max(zeros, input) +
          torch::min(zeros, alpha * (torch::exp(input / alpha) - 1.0));

      const auto options = F::CELUFuncOptions().alpha(alpha).inplace(inplace);
      const auto output = F::celu(input, options);
      const auto output_bf16 = F::celu(input_bf16, options);

      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), shape);
      ASSERT_TRUE(torch::allclose(output, expected));
      ASSERT_TRUE(
          torch::allclose(output_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      if (inplace) {
        ASSERT_TRUE(torch::allclose(input, expected));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      }
    }
  }
  ASSERT_TRUE(F::celu(torch::tensor(1.)).defined());
}
1733
// F::celu with default options (alpha = 1), in float and bfloat16.
TEST_F(FunctionalTest, CELUDefaultOptions) {
  const auto size = 3;
  const std::vector<int64_t> shape{size, size, size};
  const double alpha = 1.0;

  auto input = torch::linspace(-10.0, 10.0, size * size * size);
  input.resize_(shape);
  const auto input_bf16 = input.clone().to(torch::kBFloat16);

  const auto zeros = torch::zeros_like(input);
  const auto expected = torch::max(zeros, input) +
      torch::min(zeros, alpha * (torch::exp(input / alpha) - 1.0));

  const auto output = F::celu(input);
  const auto output_bf16 = F::celu(input_bf16);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), shape);
  ASSERT_TRUE(torch::allclose(output, expected));
  ASSERT_TRUE(
      torch::allclose(output_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
}
1750
// F::pixel_shuffle with factor 2 interleaves 4 channels of 2x2 into one 4x4 plane.
TEST_F(FunctionalTest, PixelShuffle) {
  const int64_t upscale_factor = 2;
  const auto input = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  const auto expected = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);

  const auto output = F::pixel_shuffle(input, upscale_factor);

  ASSERT_EQ(output.ndimension(), 4);
  ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 4, 4}));
  ASSERT_TRUE(output.allclose(expected));
}
1767
// F::pixel_unshuffle with factor 2 is the inverse of the PixelShuffle case:
// one 4x4 plane is split back into 4 channels of 2x2.
TEST_F(FunctionalTest, PixelUnshuffle) {
  const int64_t downscale_factor = 2;
  const auto input = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  const auto expected = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);

  const auto output = F::pixel_unshuffle(input, downscale_factor);

  ASSERT_EQ(output.ndimension(), 4);
  ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 4, 2, 2}));
  ASSERT_TRUE(output.allclose(expected));
}
1784
// F::softplus with explicit beta/threshold. Per the PyTorch docs,
// softplus(x) = log(1 + exp(beta * x)) / beta, and it reverts to the identity
// wherever beta * x > threshold.
TEST_F(FunctionalTest, Softplus) {
  const auto size = 3;
  for (const auto beta : {0.5, 1.0, 2.0}) {
    for (const auto threshold : {1.0, 3.0, 5.0}) {
      // Use exactly size^3 samples spanning [-3, 3]. A 61-point linspace
      // would be silently truncated by resize_ to its first 27 entries
      // ([-3, -0.4]), so the linear branch would never be reached.
      auto x = torch::linspace(-3.0, 3.0, size * size * size);
      x.resize_({size, size, size});
      // The linear branch is selected by the *scaled* input (x * beta), not
      // by x itself. log1p keeps the reference accurate when exp(x * beta)
      // is small.
      const auto scaled = x * beta;
      auto y_exp = (scaled <= threshold) * torch::log1p(torch::exp(scaled)) / beta +
          (scaled > threshold) * x;
      auto y = F::softplus(
          x, F::SoftplusFuncOptions().beta(beta).threshold(threshold));

      ASSERT_EQ(y.ndimension(), 3);
      ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(y, y_exp));
    }
  }
}
1803
// F::softplus with default options (beta = 1, threshold = 20). The reference
// uses the same beta * x > threshold switch that PyTorch documents.
TEST_F(FunctionalTest, SoftplusDefaultOptions) {
  const auto size = 3;
  const auto beta = 1.0;
  const auto threshold = 20.0;
  // Exactly size^3 samples spanning [-3, 3]; a 61-point linspace would be
  // silently truncated by resize_ to its first 27 (all negative) entries.
  auto x = torch::linspace(-3.0, 3.0, size * size * size);
  x.resize_({size, size, size});
  const auto scaled = x * beta;
  auto y_exp = (scaled <= threshold) * torch::log1p(torch::exp(scaled)) / beta +
      (scaled > threshold) * x;
  auto y = F::softplus(x);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1818
// F::fold sums overlapping 2x2 blocks into a 3x2 output. The middle row is
// covered by both blocks, so it accumulates to 2.
TEST_F(FunctionalTest, Fold) {
  const auto blocks = torch::ones({1, 3 * 2 * 2, 2}, torch::kDouble);
  const auto folded =
      F::fold(blocks, F::FoldFuncOptions(/*output_size=*/{3, 2}, /*kernel_size=*/{2, 2}));

  const auto expected = torch::tensor(
      {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
      torch::kDouble);

  ASSERT_EQ(folded.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
  ASSERT_TRUE(folded.allclose(expected));
}
1831
// F::unfold extracts 2x2 patches with padding 1 and stride 2 from a
// 1x2x2x3 input. That gives 2*2*2 = 8 rows and 2*2 = 4 patch locations.
TEST_F(FunctionalTest, Unfold) {
  const auto image = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3});
  const auto options = F::UnfoldFuncOptions({2, 2}).padding(1).stride(2);
  const auto patches = F::unfold(image, options);

  const auto expected = torch::tensor(
      {{{0.0, 0.0, 0.0, 4.0},
        {0.0, 0.0, 3.0, 5.0},
        {0.0, 1.0, 0.0, 0.0},
        {0.0, 2.0, 0.0, 0.0},
        {0.0, 0.0, 0.0, 10.0},
        {0.0, 0.0, 9.0, 11.0},
        {0.0, 7.0, 0.0, 0.0},
        {6.0, 8.0, 0.0, 0.0}}},
      torch::kDouble);

  ASSERT_EQ(patches.sizes(), std::vector<int64_t>({1, 8, 4}));
  ASSERT_TRUE(patches.allclose(expected));
}
1850
// F::softshrink for several lambdas, checked against
// (x < -l) * (x + l) + (x > l) * (x - l). The test also backpropagates
// through the result to make sure it is differentiable.
TEST_F(FunctionalTest, Softshrink) {
  const auto size = 3;
  const std::vector<int64_t> shape{size, size, size};
  for (const double lambd : {0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_(shape).set_requires_grad(true);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    const auto output = F::softshrink(input, /*lambda=*/lambd);
    torch::Tensor total = output.sum();

    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), shape);
    const auto expected =
        (input < -lambd) * (input + lambd) + (input > lambd) * (input - lambd);
    ASSERT_TRUE(torch::allclose(output, expected));
  }
}
1869
// F::softshrink with the default lambda (0.5): shape, differentiability and
// values are all checked.
TEST_F(FunctionalTest, SoftshrinkDefaultOptions) {
  const auto size = 3;
  const auto lambda = 0.5;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size}).set_requires_grad(true);
  auto y = F::softshrink(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
  // Previously y_exp was computed but never compared, leaving values untested.
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1885
// F::softsign(x) == x / (1 + |x|) on random inputs.
TEST_F(FunctionalTest, Softsign) {
  const auto input = torch::randn(100).mul(10);
  const auto denominator = input.abs() + 1;
  const auto output = F::softsign(input);
  ASSERT_TRUE(torch::allclose(output, input / denominator));
}
1893
// F::mish(x) == x * tanh(softplus(x)), where softplus(x) = log1p(exp(x)).
TEST_F(FunctionalTest, Mish) {
  const auto input = torch::randn(100).mul(10);
  const auto expected = input * torch::tanh(torch::log1p(torch::exp(input)));
  const auto output = F::mish(input);
  ASSERT_TRUE(torch::allclose(output, expected));
}
1901
// F::tanhshrink(x) == x - tanh(x) on random inputs.
TEST_F(FunctionalTest, Tanhshrink) {
  const auto input = torch::randn(100).mul(10);
  const auto expected = input - torch::tanh(input);
  ASSERT_TRUE(torch::allclose(F::tanhshrink(input), expected));
}
1909
// F::threshold: entries <= threshold are replaced by `value` and the rest pass
// through unchanged. Checked both out of place and in place.
TEST_F(FunctionalTest, Threshold) {
  const auto size = 3;
  for (const auto threshold : {0.5, 1.0, 2.0}) {
    for (const auto value : {0.5, 1.0, 2.0}) {
      for (const auto inplace : {false, true}) {
        // Use exactly size^3 samples spanning [-3, 3]. A 61-point linspace
        // would be silently truncated by resize_ to its first 27 entries
        // ([-3, -0.4]), all below every threshold, so the pass-through
        // branch would never be exercised.
        auto x = torch::linspace(-3.0, 3.0, size * size * size);
        x.resize_({size, size, size});
        auto y_exp = (x <= threshold) * value + (x > threshold) * x;
        auto y = F::threshold(
            x, F::ThresholdFuncOptions(threshold, value).inplace(inplace));

        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
  ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5))
                  .defined());
}
1933
// Eval-mode F::batch_norm on (N, C) input with identity affine parameters.
// It must equal (x - running_mean) / sqrt(running_var + eps).
TEST_F(FunctionalTest, BatchNorm1d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  const auto weight = torch::ones({num_features});
  const auto bias = torch::zeros({num_features});

  const auto options = F::BatchNormFuncOptions()
                           .weight(weight)
                           .bias(bias)
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output = F::batch_norm(input, running_mean, running_var, options);

  const auto expected = (input - running_mean) / torch::sqrt(running_var + eps);
  ASSERT_TRUE(output.allclose(expected));
}
1957
// F::batch_norm with default options (eval mode, eps = 1e-5, no affine).
TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
  const auto input = torch::randn({2, 5});
  auto running_mean = torch::randn(5);
  auto running_var = torch::rand(5);

  const auto output = F::batch_norm(input, running_mean, running_var);

  const auto stddev = torch::sqrt(running_var + 1e-5);
  ASSERT_TRUE(output.allclose((input - running_mean) / stddev));
}
1966
// Eval-mode F::batch_norm on (N, C, H, W) input with identity affine
// parameters. Per-channel statistics are broadcast over N, H and W.
TEST_F(FunctionalTest, BatchNorm2d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features, 4, 4});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  const auto weight = torch::ones({num_features});
  const auto bias = torch::zeros({num_features});

  const auto options = F::BatchNormFuncOptions()
                           .weight(weight)
                           .bias(bias)
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output = F::batch_norm(input, running_mean, running_var, options);

  // Reshape the channel statistics so they broadcast along dim 1.
  const auto mean_b = running_mean.view({1, num_features, 1, 1});
  const auto stddev_b = torch::sqrt(running_var + eps).view({1, num_features, 1, 1});
  ASSERT_TRUE(output.allclose((input - mean_b) / stddev_b));
}
1993
// F::batch_norm with default options on (N, C, H, W) input.
TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) {
  const int64_t num_features = 5;
  const double eps = 1e-05;

  const auto input = torch::randn({2, num_features, 4, 4});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  const auto output = F::batch_norm(input, running_mean, running_var);

  const auto mean_b = running_mean.view({1, num_features, 1, 1});
  const auto stddev_b = torch::sqrt(running_var + eps).view({1, num_features, 1, 1});
  ASSERT_TRUE(output.allclose((input - mean_b) / stddev_b));
}
2008
// Eval-mode F::batch_norm on (N, C, D, H, W) input with identity affine
// parameters.
TEST_F(FunctionalTest, BatchNorm3d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features, 2, 2, 2});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  const auto weight = torch::ones({num_features});
  const auto bias = torch::zeros({num_features});

  const auto options = F::BatchNormFuncOptions()
                           .weight(weight)
                           .bias(bias)
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output = F::batch_norm(input, running_mean, running_var, options);

  const auto mean_b = running_mean.view({1, num_features, 1, 1, 1});
  const auto stddev_b =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1, 1});
  ASSERT_TRUE(output.allclose((input - mean_b) / stddev_b));
}
2035
// F::batch_norm with default options on (N, C, D, H, W) input.
TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
  const int64_t num_features = 5;
  const double eps = 1e-05;

  const auto input = torch::randn({2, num_features, 2, 2, 2});
  auto running_mean = torch::randn(num_features);
  auto running_var = torch::rand(num_features);
  const auto output = F::batch_norm(input, running_mean, running_var);

  const auto mean_b = running_mean.view({1, num_features, 1, 1, 1});
  const auto stddev_b =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1, 1});
  ASSERT_TRUE(output.allclose((input - mean_b) / stddev_b));
}
2050
// F::instance_norm on (N, C, L) input with running stats and an arange-based
// affine transform.
TEST_F(FunctionalTest, InstanceNorm1d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::arange(40.).view({2, num_features, 4});
  auto running_mean = torch::arange(static_cast<double>(num_features));
  auto running_var = torch::arange(static_cast<double>(num_features));
  const auto weight = torch::arange(static_cast<double>(num_features));
  const auto bias = torch::arange(static_cast<double>(num_features));

  const auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(running_mean)
          .running_var(running_var)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));

  // Both batch entries normalize to the same values, so one sample's
  // expectation is built and tiled along the batch dimension.
  const auto per_sample = torch::tensor(
      {{0.0000, 0.0000, 0.0000, 0.0000},
       {-0.3416, 0.5528, 1.4472, 2.3416},
       {-0.6833, 1.1056, 2.8944, 4.6833},
       {-1.0249, 1.6584, 4.3416, 7.0249},
       {-1.3665, 2.2112, 5.7888, 9.3665}});
  const auto expected = per_sample.repeat({2, 1, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2083
// F::instance_norm with default options. Every (sample, channel) row is an
// arithmetic progression of length 4, so every row normalizes identically.
TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
  const auto input = torch::arange(40.).view({2, 5, 4});
  const auto output = F::instance_norm(input);

  const auto row = torch::tensor({-1.3416, -0.4472, 0.4472, 1.3416});
  const auto expected = row.repeat({2, 5, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2100
// F::instance_norm on (N, C, H, W) input with running stats and an
// arange-based affine transform.
TEST_F(FunctionalTest, InstanceNorm2d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto running_mean = torch::arange(static_cast<double>(num_features));
  auto running_var = torch::arange(static_cast<double>(num_features));
  const auto weight = torch::arange(static_cast<double>(num_features));
  const auto bias = torch::arange(static_cast<double>(num_features));

  const auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(running_mean)
          .running_var(running_var)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));

  // Identical expectation for both batch entries; tile along dim 0.
  const auto per_sample = torch::tensor(
      {{{0.0000, 0.0000}, {0.0000, 0.0000}},
       {{-0.3416, 0.5528}, {1.4472, 2.3416}},
       {{-0.6833, 1.1056}, {2.8944, 4.6833}},
       {{-1.0249, 1.6584}, {4.3416, 7.0249}},
       {{-1.3665, 2.2112}, {5.7888, 9.3665}}});
  const auto expected = per_sample.repeat({2, 1, 1, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2134
// F::instance_norm with default options on (N, C, H, W). Every 2x2 plane is
// an arithmetic progression, so every plane normalizes to the same values.
TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
  const int64_t num_features = 5;

  const auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  const auto output = F::instance_norm(input);

  const auto plane = torch::tensor({{-1.3416, -0.4472}, {0.4472, 1.3416}});
  const auto expected = plane.repeat({2, num_features, 1, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2154
// F::instance_norm on (N, C, D, H, W) input with running stats and an
// arange-based affine transform.
TEST_F(FunctionalTest, InstanceNorm3d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                         .view({2, num_features, 2, 2, 2});
  auto running_mean = torch::arange(static_cast<double>(num_features));
  auto running_var = torch::arange(static_cast<double>(num_features));
  const auto weight = torch::arange(static_cast<double>(num_features));
  const auto bias = torch::arange(static_cast<double>(num_features));

  const auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(running_mean)
          .running_var(running_var)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));

  // Identical expectation for both batch entries; tile along dim 0.
  const auto per_sample = torch::tensor(
      {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
        {{0.0000, 0.0000}, {0.0000, 0.0000}}},
       {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
        {{1.2182, 1.6547}, {2.0911, 2.5275}}},
       {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
        {{2.4364, 3.3093}, {4.1822, 5.0550}}},
       {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
        {{3.6547, 4.9640}, {6.2733, 7.5826}}},
       {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
        {{4.8729, 6.6186}, {8.3644, 10.1101}}}});
  const auto expected = per_sample.repeat({2, 1, 1, 1, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2198
// F::instance_norm with default options on (N, C, D, H, W). Every 2x2x2
// volume is an arithmetic progression, so all volumes normalize identically.
TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
  const int64_t num_features = 5;

  const auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                         .view({2, num_features, 2, 2, 2});
  const auto output = F::instance_norm(input);

  const auto volume = torch::tensor(
      {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
       {{0.2182, 0.6547}, {1.0911, 1.5275}}});
  const auto expected = volume.repeat({2, num_features, 1, 1, 1});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2228
// Exercises F::interpolate end to end:
//  - resampling of constant inputs in 1D, 2D and 3D;
//  - the argument-validation error messages;
//  - agreement with the ATen nearest-exact and antialiased bilinear/bicubic
//    kernels.
TEST_F(FunctionalTest, Interpolate) {
  {
    // 1D nearest upsampling: a constant signal stays constant.
    const auto signal = torch::ones({1, 1, 2});
    const auto opts = F::InterpolateFuncOptions()
                          .size(std::vector<int64_t>({4}))
                          .mode(torch::kNearest);
    const auto resampled = F::interpolate(signal, opts);
    ASSERT_TRUE(resampled.allclose(torch::ones({1, 1, 4})));
  }
  {
    // 2D bilinear resampling with float scale factors (down and up).
    for (const bool corners : {true, false}) {
      for (const double factor : {0.5, 1.5, 2.0}) {
        const auto image = torch::ones({1, 1, 2, 2});
        const auto opts =
            F::InterpolateFuncOptions()
                .scale_factor(std::vector<double>({factor, factor}))
                .mode(torch::kBilinear)
                .align_corners(corners);
        const auto resampled = F::interpolate(image, opts);
        const auto side =
            static_cast<int64_t>(std::floor(image.size(-1) * factor));
        ASSERT_TRUE(resampled.allclose(torch::ones({1, 1, side, side})));
      }
    }
  }
  {
    // 3D trilinear resampling with float scale factors (down and up).
    for (const bool corners : {true, false}) {
      for (const double factor : {0.5, 1.5, 2.0}) {
        const auto volume = torch::ones({1, 1, 2, 2, 2});
        const auto opts = F::InterpolateFuncOptions()
                              .scale_factor(std::vector<double>(
                                  {factor, factor, factor}))
                              .mode(torch::kTrilinear)
                              .align_corners(corners);
        const auto resampled = F::interpolate(volume, opts);
        const auto side =
            static_cast<int64_t>(std::floor(volume.size(-1) * factor));
        ASSERT_TRUE(
            resampled.allclose(torch::ones({1, 1, side, side, side})));
      }
    }
  }
  {
    // A 1D tensor has no batch/channel dims and is rejected.
    ASSERT_THROWS_WITH(
        F::interpolate(
            torch::randn({1}),
            F::InterpolateFuncOptions().size(std::vector<int64_t>({1}))),
        "Input Error: Only 3D, 4D and 5D input Tensors supported (got 1D) ");
  }
  {
    const auto sample = torch::randn({3, 2, 2});
    // 2D and 6D inputs fall outside the supported 3D-5D range.
    ASSERT_THROWS_WITH(
        F::interpolate(
            sample[0],
            F::InterpolateFuncOptions().size(std::vector<int64_t>({4, 4}))),
        "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) "
        "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
    ASSERT_THROWS_WITH(
        F::interpolate(
            torch::reshape(sample, {1, 1, 1, 3, 2, 2}),
            F::InterpolateFuncOptions().size(
                std::vector<int64_t>({1, 1, 1, 3, 4, 4}))),
        "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) "
        "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
    // Exactly one of size / scale_factor must be given.
    ASSERT_THROWS_WITH(
        F::interpolate(sample, F::InterpolateFuncOptions()),
        "either size or scale_factor should be defined");
    ASSERT_THROWS_WITH(
        F::interpolate(
            sample,
            F::InterpolateFuncOptions()
                .size(std::vector<int64_t>({3, 4, 4}))
                .scale_factor(std::vector<double>({0.5}))),
        "only one of size or scale_factor should be defined");
    // scale_factor length must match the number of spatial dims.
    ASSERT_THROWS_WITH(
        F::interpolate(
            sample,
            F::InterpolateFuncOptions().scale_factor(
                std::vector<double>({3, 2}))),
        "scale_factor shape must match input shape. "
        "Input is 1D, scale_factor size is [3, 2]");
    // align_corners is meaningless for nearest-neighbour modes.
    ASSERT_THROWS_WITH(
        F::interpolate(
            sample,
            F::InterpolateFuncOptions()
                .mode(torch::kNearest)
                .align_corners(true)),
        "align_corners option can only be set with the "
        "interpolating modes: linear | bilinear | bicubic | trilinear");
  }
  {
    // kNearestExact must agree with the ATen nearest-exact kernel.
    const auto batch = torch::rand({2, 3, 32, 32});
    const std::vector<int64_t> out_size = {8, 10};
    const auto reference =
        at::native::_upsample_nearest_exact2d(batch, out_size, torch::nullopt);

    const auto opts = F::InterpolateFuncOptions()
                          .size(out_size)
                          .mode(torch::kNearestExact)
                          .align_corners(false);
    ASSERT_TRUE(F::interpolate(batch, opts).allclose(reference));
  }
  {
    // Antialiased bilinear must agree with the ATen _aa kernel.
    const auto batch = torch::rand({2, 3, 32, 32});
    const std::vector<int64_t> out_size = {8, 10};
    const auto reference = at::native::_upsample_bilinear2d_aa(
        batch, out_size, false, torch::nullopt);

    const auto opts = F::InterpolateFuncOptions()
                          .size(out_size)
                          .mode(torch::kBilinear)
                          .align_corners(false)
                          .antialias(true);
    ASSERT_TRUE(F::interpolate(batch, opts).allclose(reference));
  }
  {
    // Antialiased bicubic must agree with the ATen _aa kernel.
    const auto batch = torch::rand({2, 3, 32, 32});
    const std::vector<int64_t> out_size = {8, 10};
    const auto reference = at::native::_upsample_bicubic2d_aa(
        batch, out_size, false, torch::nullopt);

    const auto opts = F::InterpolateFuncOptions()
                          .size(out_size)
                          .mode(torch::kBicubic)
                          .align_corners(false)
                          .antialias(true);
    ASSERT_TRUE(F::interpolate(batch, opts).allclose(reference));
  }
}
2372
// Circular padding of the last dim of a 1x2x3 tensor by (left=1, right=2).
TEST_F(FunctionalTest, Pad1) {
  const auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
  const auto options = F::PadFuncOptions({1, 2}).mode(torch::kCircular);
  const auto padded = F::pad(input, options);

  const auto expected = torch::tensor(
      {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble);
  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 2, 6}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
// Circular padding of a 3x3 plane by (left=3, right=3, top=3, bottom=1).
TEST_F(FunctionalTest, Pad2) {
  const auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
  const auto options = F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular);
  const auto padded = F::pad(input, options);

  const auto expected = torch::tensor(
      {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
         {0., 1., 2., 0., 1., 2., 0., 1., 2.},
         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
         {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}},
      torch::kDouble);
  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 1, 7, 9}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
// Circular padding of a 2x2x3 volume by (left=3, right=3, top=2, bottom=1,
// front=2, back=2).
TEST_F(FunctionalTest, Pad3) {
  const auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
  const auto options =
      F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular);
  const auto padded = F::pad(input, options);

  // Each padded depth slice is a wrapped copy of input depth 0 (`even`) or
  // input depth 1 (`odd`, which is `even` + 6). The slices alternate
  // even/odd across all 6 output depths.
  const auto even = torch::tensor(
      {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
       {3., 4., 5., 3., 4., 5., 3., 4., 5.},
       {0., 1., 2., 0., 1., 2., 0., 1., 2.},
       {3., 4., 5., 3., 4., 5., 3., 4., 5.},
       {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
      torch::kDouble);
  const auto odd = even + 6;
  const auto expected =
      torch::stack({even, odd, even, odd, even, odd}).view({1, 1, 6, 5, 9});

  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
// Reflect padding by 1 on every side of each 2x2 plane of a 2x2x2x2 tensor.
TEST_F(FunctionalTest, Pad4) {
  const auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
  const auto options = F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect);
  const auto padded = F::pad(input, options);

  // Every plane reflects the same way. Plane k holds base values + 4 * k,
  // where k runs over the flattened (batch, channel) index.
  const auto base = torch::tensor(
      {{3., 2., 3., 2.},
       {1., 0., 1., 0.},
       {3., 2., 3., 2.},
       {1., 0., 1., 0.}},
      torch::kDouble);
  const auto offsets = torch::arange(0, 16, 4, torch::kDouble).view({2, 2, 1, 1});
  const auto expected = base + offsets;

  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({2, 2, 4, 4}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
// Replicate padding of a 2x2x3 volume by (left=1, right=2, top=2, bottom=1,
// front=1, back=2).
TEST_F(FunctionalTest, Pad5) {
  const auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
  const auto options =
      F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate);
  const auto padded = F::pad(input, options);

  // Output depths 0-1 replicate input depth 0 (`front`). Output depths 2-4
  // replicate input depth 1 (`back`, which is `front` + 6).
  const auto front = torch::tensor(
      {{0., 0., 1., 2., 2., 2.},
       {0., 0., 1., 2., 2., 2.},
       {0., 0., 1., 2., 2., 2.},
       {3., 3., 4., 5., 5., 5.},
       {3., 3., 4., 5., 5., 5.}},
      torch::kDouble);
  const auto back = front + 6;
  const auto expected =
      torch::stack({front, front, back, back, back}).view({1, 1, 5, 5, 6});

  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
// Reflect padding of a 3x2x3 volume by (left=0, right=2, top=1, bottom=0,
// front=1, back=2).
TEST_F(FunctionalTest, Pad6) {
  const auto input = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3});
  const auto options =
      F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect);
  const auto padded = F::pad(input, options);

  // Each output depth slice is the H/W-reflected form of one input depth d.
  // Slice d equals `depth0` + 6 * d. The depth order after reflection is
  // 1, 0, 1, 2, 1, 0.
  const auto depth0 = torch::tensor(
      {{3., 4., 5., 4., 3.}, {0., 1., 2., 1., 0.}, {3., 4., 5., 4., 3.}},
      torch::kDouble);
  const auto depth1 = depth0 + 6;
  const auto depth2 = depth0 + 12;
  const auto expected =
      torch::stack({depth1, depth0, depth1, depth2, depth1, depth0})
          .view({1, 1, 6, 3, 5});

  ASSERT_EQ(padded.sizes(), std::vector<int64_t>({1, 1, 6, 3, 5}));
  ASSERT_TRUE(padded.allclose(expected, 1e-04));
}
TEST_F(FunctionalTest, Pad7) {
  {
    // Constant-pad the last dimension of a single-element tensor with one
    // explicit zero on each side: [1] -> [0, 1, 0].
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(
        input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Check the values as well as the shape; `expected` was previously built
    // but never compared.
    ASSERT_TRUE(output.allclose(expected));
  }
}
TEST_F(FunctionalTest, Pad8) {
  {
    // Pad the last dimension using only the pad list. The default options are
    // expected to constant-pad with zeros: [1] -> [0, 1, 0].
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(input, F::PadFuncOptions({1, 1}));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Check the values as well as the shape; `expected` was previously built
    // but never compared.
    ASSERT_TRUE(output.allclose(expected));
  }
}
2567
TEST_F(FunctionalTest, CTCLoss) {
  // Type checks: input_lengths and target_lengths must be integral tensors.
  {
    const auto tgt_lens = torch::tensor({30, 25, 20});
    const auto in_lens = torch::tensor({50, 50, 50});
    const auto tgts =
        torch::randint(1, 15, {tgt_lens.sum().item<int>()}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);

    const auto float_in_lens = in_lens.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, tgts, float_in_lens, tgt_lens),
        "input_lengths must be integral");

    const auto float_tgt_lens = tgt_lens.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, tgts, in_lens, float_tgt_lens),
        "target_lengths must be integral");
  }

  // Length check: a padded (3, 29) target matrix cannot hold a target whose
  // declared length is 30.
  {
    const auto tgt_lens = torch::tensor({30, 25, 20});
    const auto in_lens = torch::tensor({50, 50, 50});
    const auto tgts = torch::randint(1, 15, {3, 29}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, tgts, in_lens, tgt_lens),
        "Expected tensor to have size at least 30 at dimension 1");
  }

  // Empty targets: with per-sample losses (kNone), a sample whose target is
  // empty has a non-negative loss equal to minus the log-probability of
  // class 0 summed over all 50 time steps.
  {
    const auto per_sample = F::CTCLossFuncOptions().reduction(torch::kNone);
    {
      // Every target in the batch is empty.
      const auto tgt_lens = torch::tensor({0, 0, 0});
      const auto in_lens = torch::tensor({50, 50, 50});
      const auto tgts =
          torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss =
          F::ctc_loss(log_probs, tgts, in_lens, tgt_lens, per_sample);
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      const auto class0_nll = -log_probs.sum(0).slice(1, 0, 1).view_as(loss);
      ASSERT_TRUE(torch::allclose(class0_nll, loss));
    }
    {
      // Mixed batch: only the middle sample has a (length-9) target.
      const auto tgt_lens = torch::tensor({0, 9, 0});
      const auto in_lens = torch::tensor({50, 50, 50});
      const auto tgts = torch::randint(1, 15, {9}, torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss =
          F::ctc_loss(log_probs, tgts, in_lens, tgt_lens, per_sample);
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      // Only rows 0 and 2 (the empty targets) have the closed-form value.
      const auto empty_rows = torch::tensor({0, 2}, torch::kLong);
      const auto class0_nll = -log_probs.sum(0)
                                   .index_select(0, empty_rows)
                                   .slice(1, 0, 1)
                                   .view({2});
      ASSERT_TRUE(
          torch::allclose(class0_nll, loss.index_select(0, empty_rows)));
    }
  }
}
2637
TEST_F(FunctionalTest, PoissonNLLLoss) {
  // Reference per-element loss: exp(input) - target * input. Each reduction
  // mode is checked against the matching reduction of that reference.
  const auto input = torch::tensor({0.5, 1.5, 2.5});
  const auto target = torch::tensor({1., 2., 3.});
  const auto elementwise = torch::exp(input) - target * input;

  // Loss computed with an explicitly requested reduction mode.
  const auto reduced_with = [&](const auto& mode) {
    return F::poisson_nll_loss(
        input, target, F::PoissonNLLLossFuncOptions().reduction(mode));
  };

  // Default options must agree with the mean of the per-element losses.
  ASSERT_TRUE(torch::allclose(
      torch::mean(elementwise), F::poisson_nll_loss(input, target)));
  ASSERT_TRUE(torch::allclose(elementwise, reduced_with(torch::kNone)));
  ASSERT_TRUE(
      torch::allclose(torch::sum(elementwise), reduced_with(torch::kSum)));
  ASSERT_TRUE(
      torch::allclose(torch::mean(elementwise), reduced_with(torch::kMean)));
}
2663
TEST_F(FunctionalTest, MarginRankingLoss) {
  // Reference: loss_i = max(0, -target_i * (input1_i - input2_i) + margin),
  // followed by the requested reduction.
  {
    // Default options; the reference assumes margin 0 and mean reduction.
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(input1, input2, target),
        (-target * (input1 - input2)).clamp(0).mean()));
  }
  {
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    const auto margin = 0.5;
    // Pass `margin` itself (not a duplicated literal) so the options and the
    // reference formula cannot drift apart.
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(
            input1,
            input2,
            target,
            F::MarginRankingLossFuncOptions().margin(margin).reduction(
                torch::kSum)),
        (-target * (input1 - input2) + margin).clamp(0).sum()));
  }
  {
    const auto input1 = torch::randn(15) * 10;
    const auto input2 = torch::randn(15) * 10;
    const auto target = torch::randn(15).sign();
    const auto margin = 0.5;
    ASSERT_TRUE(torch::allclose(
        F::margin_ranking_loss(
            input1,
            input2,
            target,
            F::MarginRankingLossFuncOptions().margin(margin).reduction(
                torch::kMean)),
        (-target * (input1 - input2) + margin).clamp(0).mean()));
  }
}
2702
TEST_F(FunctionalTest, ConvTranspose1d) {
  // Input is (N=2, C_in=2, L=5); the transposed-convolution weight is laid
  // out as (C_in=2, C_out=3, K=3), so with stride 1 the result is
  // (2, 3, 5 + 3 - 1).
  const auto signal = torch::arange(20.).view({2, 2, 5});
  const auto kernel = torch::arange(18.).view({2, 3, 3});
  const auto expected = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});

  const auto out_explicit = F::conv_transpose1d(
      signal, kernel, F::ConvTranspose1dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(out_explicit, expected));

  // Calling without options must reproduce the stride-1 result.
  const auto out_default = F::conv_transpose1d(signal, kernel);
  ASSERT_TRUE(torch::allclose(out_default, expected));
}
2720
TEST_F(FunctionalTest, ConvTranspose2dEven) {
  // Square case: input (N=1, C_in=2, 5, 5) with the transposed-convolution
  // weight laid out as (C_in=2, C_out=3, 3, 3). With stride 1 each spatial
  // side grows to 5 + 3 - 1 = 7, so `expected` is (1, 3, 7, 7).
  auto x = torch::arange(50.).view({1, 2, 5, 5});
  auto weight = torch::arange(54.).view({2, 3, 3, 3});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  // Hand-computed reference output, one 7x7 plane per output channel.
  auto expected = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // Calling without options must reproduce the stride-1 result.
  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2753
TEST_F(FunctionalTest, ConvTranspose2dUneven) {
  // Non-square case: input (N=1, C_in=2, H=5, W=4) with the transposed-
  // convolution weight laid out as (C_in=2, C_out=3, kH=3, kW=2). With
  // stride 1 the output is (1, 3, 5 + 3 - 1, 4 + 2 - 1) = (1, 3, 7, 5).
  auto x = torch::arange(40.).view({1, 2, 5, 4});
  auto weight = torch::arange(36.).view({2, 3, 3, 2});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  // Hand-computed reference output, one 7x5 plane per output channel.
  auto expected = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  // Calling without options must reproduce the stride-1 result.
  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2786
TEST_F(FunctionalTest, ConvTranspose3d) {
  // Input is (N=1, C_in=2, 2, 2, 2); the transposed-convolution weight is laid
  // out as (C_in=2, C_out=2, 2, 2, 2), so with stride 1 every spatial side
  // grows to 2 + 2 - 1 = 3 and the result is (1, 2, 3, 3, 3).
  const auto volume = torch::arange(16.).view({1, 2, 2, 2, 2});
  const auto kernel = torch::arange(32.).view({2, 2, 2, 2, 2});
  const auto expected = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});

  const auto out_explicit = F::conv_transpose3d(
      volume, kernel, F::ConvTranspose3dFuncOptions().stride(1));
  ASSERT_TRUE(torch::allclose(out_explicit, expected));

  // Calling without options must reproduce the stride-1 result.
  const auto out_default = F::conv_transpose3d(volume, kernel);
  ASSERT_TRUE(torch::allclose(out_default, expected));
}
2804
TEST_F(FunctionalTest, AlphaDropout) {
  // With training disabled, alpha dropout must leave the sample mean and
  // standard deviation (within 10%) unchanged for every rate, both in-place
  // and out-of-place.
  auto samples = torch::randn(5000);
  auto samples_mean = samples.mean();
  auto samples_std = samples.std();

  for (const double p : {0.2, 0.5, 0.8}) {
    for (const bool in_place : {false, true}) {
      auto work = samples.clone();
      const auto opts =
          F::AlphaDropoutFuncOptions().p(p).training(false).inplace(in_place);
      auto result = F::alpha_dropout(work, opts);
      ASSERT_TRUE(torch::allclose(samples_mean, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(samples_std, result.std(), 0.1));
      if (in_place) {
        // In-place mode: the returned tensor must match the input buffer.
        ASSERT_TRUE(torch::allclose(work, result));
      }
    }
  }

  // Lower-level entry point with positional arguments.
  // NOTE(review): presumably (p, training, inplace) — confirm the signature.
  auto result = F::detail::alpha_dropout(samples, 0.5, false, false);
  ASSERT_TRUE(torch::allclose(samples_mean, result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(samples_std, result.std(), 0.1));
}
2828
TEST_F(FunctionalTest, FeatureAlphaDropout) {
  // With training disabled, feature alpha dropout must leave the sample mean
  // and standard deviation (within 10%) unchanged for every rate, both
  // in-place and out-of-place.
  auto samples = torch::randn(5000);
  auto samples_mean = samples.mean();
  auto samples_std = samples.std();

  for (const double p : {0.2, 0.5, 0.8}) {
    for (const bool in_place : {false, true}) {
      auto work = samples.clone();
      const auto opts = F::FeatureAlphaDropoutFuncOptions()
                            .p(p)
                            .training(false)
                            .inplace(in_place);
      auto result = F::feature_alpha_dropout(work, opts);
      ASSERT_TRUE(torch::allclose(samples_mean, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(samples_std, result.std(), 0.1));
      if (in_place) {
        // In-place mode: the returned tensor must match the input buffer.
        ASSERT_TRUE(torch::allclose(work, result));
      }
    }
  }

  // Default options must preserve the statistics as well.
  auto result = F::feature_alpha_dropout(samples);
  ASSERT_TRUE(torch::allclose(samples_mean, result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(samples_std, result.std(), 0.1));
}
2852
TEST_F(FunctionalTest, Dropout) {
  // Dropout should approximately preserve the sample mean, and the standard
  // deviation of its output must not fall below that of the input.
  auto samples = torch::randn(5000);
  auto samples_mean = samples.mean();
  auto samples_std = samples.std();

  for (const double p : {0.2, 0.5, 0.8}) {
    auto dropped = F::dropout(samples, F::DropoutFuncOptions().p(p));
    ASSERT_TRUE(torch::allclose(samples_mean, dropped.mean(), 0.01, 0.05));
    ASSERT_TRUE((samples_std <= dropped.std()).all().item<bool>());
  }

  // Same properties with default options.
  auto dropped = F::dropout(samples);
  ASSERT_TRUE(torch::allclose(samples_mean, dropped.mean(), 0.01, 0.05));
  ASSERT_TRUE((samples_std <= dropped.std()).all().item<bool>());

  // A scalar input must still produce a defined result.
  ASSERT_TRUE(F::dropout(torch::tensor(1.)).defined());
}
2868
TEST_F(FunctionalTest, Dropout2d) {
  // dropout2d on a 4-D input should approximately preserve the global mean,
  // for each rate and with default options; a 3-D input must also be accepted.
  // (The unused `input_std` local was removed.)
  auto input = torch::randn({2, 2, 50, 100});
  auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }
  auto output = F::dropout2d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  ASSERT_TRUE(F::dropout2d(torch::randn({2, 50, 100})).defined());
}
2882
TEST_F(FunctionalTest, Dropout3d) {
  // dropout3d on a 5-D input should approximately preserve the global mean,
  // for each rate and with default options; a 4-D input must also be accepted.
  // (The unused `input_std` local was removed.)
  auto input = torch::randn({2, 2, 50, 10, 10});
  auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }
  auto output = F::dropout3d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined());
}
2896
2897 template <c10::ScalarType S, typename T>
test_isfinite(const at::Device & device)2898 void test_isfinite(const at::Device& device) {
2899 const std::vector<T> values = {
2900 std::numeric_limits<T>::lowest(),
2901 0,
2902 1,
2903 42,
2904 std::numeric_limits<T>::min(),
2905 std::numeric_limits<T>::max()};
2906 for (const auto value : values) {
2907 const auto x = torch::full(
2908 {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2909 ASSERT_TRUE(torch::isfinite(x).all().template item<bool>());
2910 }
2911 if (std::numeric_limits<T>::has_infinity) {
2912 const auto inf = std::numeric_limits<T>::infinity();
2913 const auto x = torch::tensor(
2914 {-inf,
2915 std::numeric_limits<T>::lowest(),
2916 static_cast<T>(0),
2917 static_cast<T>(1),
2918 static_cast<T>(42),
2919 std::numeric_limits<T>::min(),
2920 std::numeric_limits<T>::max(),
2921 inf},
2922 torch::TensorOptions().dtype(S).device(device));
2923 ASSERT_TRUE(torch::allclose(
2924 // torch::allclose does not support comparing torch::kBool
2925 torch::isfinite(x).toType(torch::kInt),
2926 torch::tensor(
2927 {false, true, true, true, true, true, true, false},
2928 torch::TensorOptions().device(device))
2929 .toType(torch::kInt)));
2930 }
2931 if (std::numeric_limits<T>::has_quiet_NaN) {
2932 const auto x = torch::tensor(
2933 {std::numeric_limits<T>::quiet_NaN()},
2934 torch::TensorOptions().dtype(S).device(device));
2935 ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2936 }
2937 if (std::numeric_limits<T>::has_signaling_NaN) {
2938 const auto x = torch::tensor(
2939 {std::numeric_limits<T>::signaling_NaN()},
2940 torch::TensorOptions().dtype(S).device(device));
2941 ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2942 }
2943 }
2944
TEST_F(FunctionalTest, isfinite) {
  // Run the isfinite checks on the CPU for each integral and floating dtype.
  const at::Device cpu(torch::kCPU);
  test_isfinite<torch::kUInt8, uint8_t>(cpu);
  test_isfinite<torch::kInt8, int8_t>(cpu);
  test_isfinite<torch::kInt16, int16_t>(cpu);
  test_isfinite<torch::kInt32, int32_t>(cpu);
  test_isfinite<torch::kInt64, int64_t>(cpu);
  test_isfinite<torch::kFloat32, float>(cpu);
  test_isfinite<torch::kFloat64, double>(cpu);
}
2955
TEST_F(FunctionalTest, isfinite_CUDA) {
  // Same dtype sweep on CUDA, plus half precision.
  const at::Device gpu(torch::kCUDA);
  test_isfinite<torch::kUInt8, uint8_t>(gpu);
  test_isfinite<torch::kInt8, int8_t>(gpu);
  test_isfinite<torch::kInt16, int16_t>(gpu);
  test_isfinite<torch::kInt32, int32_t>(gpu);
  test_isfinite<torch::kInt64, int64_t>(gpu);
  test_isfinite<torch::kFloat32, float>(gpu);
  test_isfinite<torch::kFloat64, double>(gpu);
  test_isfinite<torch::kFloat16, c10::Half>(gpu);
}
2967
2968 template <c10::ScalarType S, typename T>
test_isinf(const at::Device & device)2969 void test_isinf(const at::Device& device) {
2970 const std::vector<T> values = {
2971 std::numeric_limits<T>::lowest(),
2972 0,
2973 1,
2974 42,
2975 std::numeric_limits<T>::min(),
2976 std::numeric_limits<T>::max()};
2977 for (const auto value : values) {
2978 const auto x = torch::full(
2979 {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2980 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2981 }
2982 if (std::numeric_limits<T>::has_infinity) {
2983 const auto inf = std::numeric_limits<T>::infinity();
2984 const auto x = torch::tensor(
2985 {-inf,
2986 std::numeric_limits<T>::lowest(),
2987 static_cast<T>(0),
2988 static_cast<T>(1),
2989 static_cast<T>(42),
2990 std::numeric_limits<T>::min(),
2991 std::numeric_limits<T>::max(),
2992 inf},
2993 torch::TensorOptions().dtype(S).device(device));
2994 ASSERT_TRUE(torch::allclose(
2995 // torch::allclose does not support comparing torch::kBool
2996 torch::isinf(x).toType(torch::kInt),
2997 torch::tensor(
2998 {true, false, false, false, false, false, false, true},
2999 torch::TensorOptions().device(device))
3000 .toType(torch::kInt)));
3001 }
3002 if (std::numeric_limits<T>::has_quiet_NaN) {
3003 const auto x = torch::tensor(
3004 {std::numeric_limits<T>::quiet_NaN()},
3005 torch::TensorOptions().dtype(S).device(device));
3006 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3007 }
3008 if (std::numeric_limits<T>::has_signaling_NaN) {
3009 const auto x = torch::tensor(
3010 {std::numeric_limits<T>::signaling_NaN()},
3011 torch::TensorOptions().dtype(S).device(device));
3012 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3013 }
3014 }
3015
TEST_F(FunctionalTest, isinf) {
  // Run the isinf checks on the CPU for each integral and floating dtype.
  const at::Device cpu(torch::kCPU);
  test_isinf<torch::kUInt8, uint8_t>(cpu);
  test_isinf<torch::kInt8, int8_t>(cpu);
  test_isinf<torch::kInt16, int16_t>(cpu);
  test_isinf<torch::kInt32, int32_t>(cpu);
  test_isinf<torch::kInt64, int64_t>(cpu);
  test_isinf<torch::kFloat32, float>(cpu);
  test_isinf<torch::kFloat64, double>(cpu);
}
3026
TEST_F(FunctionalTest, isinf_CUDA) {
  // Same dtype sweep on CUDA, plus half precision.
  const at::Device gpu(torch::kCUDA);
  test_isinf<torch::kUInt8, uint8_t>(gpu);
  test_isinf<torch::kInt8, int8_t>(gpu);
  test_isinf<torch::kInt16, int16_t>(gpu);
  test_isinf<torch::kInt32, int32_t>(gpu);
  test_isinf<torch::kInt64, int64_t>(gpu);
  test_isinf<torch::kFloat32, float>(gpu);
  test_isinf<torch::kFloat64, double>(gpu);
  test_isinf<torch::kFloat16, c10::Half>(gpu);
}
3038
3039 template <c10::ScalarType S, typename T>
test_allclose(const at::Device & device)3040 void test_allclose(const at::Device& device) {
3041 const std::vector<T> values = {
3042 std::numeric_limits<T>::lowest(),
3043 0,
3044 1,
3045 42,
3046 std::numeric_limits<T>::min(),
3047 std::numeric_limits<T>::max()};
3048 for (const auto value : values) {
3049 const auto x =
3050 torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3051 const auto y =
3052 torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3053 ASSERT_TRUE(torch::allclose(x, x));
3054 ASSERT_TRUE(torch::allclose(x, y));
3055 ASSERT_TRUE(torch::allclose(y, x));
3056 ASSERT_FALSE(torch::allclose(1.1 * x + 0.1, 1.0 * x));
3057 ASSERT_TRUE(torch::allclose(0.99 * x + 0.1, 1.0 * x, 1.1, 0.1));
3058 }
3059 if (std::numeric_limits<T>::has_infinity) {
3060 const auto inf = std::numeric_limits<T>::infinity();
3061 const auto x = torch::tensor(
3062 {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3063 const auto y = torch::tensor(
3064 {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3065 ASSERT_TRUE(torch::allclose(x, x));
3066 ASSERT_TRUE(torch::allclose(x, y));
3067 ASSERT_TRUE(torch::allclose(y, x));
3068 }
3069 if (std::numeric_limits<T>::has_quiet_NaN) {
3070 const auto x = torch::tensor(
3071 {std::numeric_limits<T>::quiet_NaN()},
3072 torch::TensorOptions().dtype(S).device(device));
3073 const auto y = torch::tensor(
3074 {std::numeric_limits<T>::quiet_NaN()},
3075 torch::TensorOptions().dtype(S).device(device));
3076 ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3077 ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3078 ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3079 }
3080 if (std::numeric_limits<T>::has_signaling_NaN) {
3081 const auto x = torch::tensor(
3082 {std::numeric_limits<T>::signaling_NaN()},
3083 torch::TensorOptions().dtype(S).device(device));
3084 const auto y = torch::tensor(
3085 {std::numeric_limits<T>::signaling_NaN()},
3086 torch::TensorOptions().dtype(S).device(device));
3087 ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3088 ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3089 ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3090 }
3091 }
3092
TEST_F(FunctionalTest, AllClose) {
  // Run the allclose checks on the CPU for each integral and floating dtype.
  const at::Device cpu(torch::kCPU);
  test_allclose<torch::kUInt8, uint8_t>(cpu);
  test_allclose<torch::kInt8, int8_t>(cpu);
  test_allclose<torch::kInt16, int16_t>(cpu);
  test_allclose<torch::kInt32, int32_t>(cpu);
  test_allclose<torch::kInt64, int64_t>(cpu);
  test_allclose<torch::kFloat32, float>(cpu);
  test_allclose<torch::kFloat64, double>(cpu);
}
3103
TEST_F(FunctionalTest, AllClose_CUDA) {
  // Same dtype sweep on CUDA, plus half precision.
  const at::Device gpu(torch::kCUDA);
  test_allclose<torch::kUInt8, uint8_t>(gpu);
  test_allclose<torch::kInt8, int8_t>(gpu);
  test_allclose<torch::kInt16, int16_t>(gpu);
  test_allclose<torch::kInt32, int32_t>(gpu);
  test_allclose<torch::kInt64, int64_t>(gpu);
  test_allclose<torch::kFloat32, float>(gpu);
  test_allclose<torch::kFloat64, double>(gpu);
  test_allclose<torch::kFloat16, c10::Half>(gpu);
}
3115
TEST_F(FunctionalTest, BCEWithLogitsLoss) {
  // Covers binary_cross_entropy_with_logits: shape validation, agreement with
  // sigmoid + binary_cross_entropy, gradients at zero, and broadcasting of
  // `weight` / `pos_weight`, plus numerical stability for large logits.
  { // test BCE with logits raises if target and input are different size
    {
      const auto target = torch::rand(5);
      const auto input = torch::rand({5, 1});
      ASSERT_THROWS_WITH(
          F::binary_cross_entropy_with_logits(input, target),
          "must be the same as input size");
    }
    {
      const auto target = torch::rand({5, 1});
      const auto input = torch::rand(5);
      ASSERT_THROWS_WITH(
          F::binary_cross_entropy_with_logits(input, target),
          "must be the same as input size");
    }
  }
  { // test BCE with logits gives same result as sigmoid and bce loss
    auto sigmoid = Sigmoid();

    auto target = torch::rand({64, 4});
    auto output = torch::rand({64, 4}) - 0.5;

    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(output, target),
        F::binary_cross_entropy(sigmoid(output), target)));

    auto weight = torch::rand(4);
    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(
            output,
            target,
            F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
        F::binary_cross_entropy(
            sigmoid(output),
            target,
            F::BinaryCrossEntropyFuncOptions().weight(weight))));

    target = torch::zeros({4, 1}, torch::kFloat);
    output = torch::empty({4, 1}, torch::kFloat).fill_(-100);

    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(output, target),
        F::binary_cross_entropy(sigmoid(output), target)));

    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(
            output,
            target,
            F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(
                torch::kNone)),
        F::binary_cross_entropy(
            sigmoid(output),
            target,
            F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))));

    weight = torch::rand({1}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(
            output,
            target,
            F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
        F::binary_cross_entropy(
            sigmoid(output),
            target,
            F::BinaryCrossEntropyFuncOptions().weight(weight))));
  }
  { // test BCE with logits has correct grad at zero
    const auto output = torch::zeros({3, 1}, torch::requires_grad());
    const auto target = torch::zeros({3, 1});
    F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum))
        .backward();
    const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
    ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
  }
  { // test BCE with logits broadcasts weights
    const auto target = torch::rand({16, 4});
    const auto output = torch::rand({16, 4}) - 0.5;

    auto weight = torch::rand(4);
    auto out1 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));

    weight = weight.expand({16, 4}).contiguous();
    auto out2 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));

    ASSERT_TRUE(torch::allclose(out1, out2));

    weight = torch::rand({16, 1});
    out1 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));

    weight = weight.expand({16, 4}).contiguous();
    out2 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));

    ASSERT_TRUE(torch::allclose(out1, out2));
  }
  { // test BCE with logits ones in pos weights are the same as none
    const auto target = torch::rand({64, 4});
    const auto output = torch::rand({64, 4}) - 0.5;
    const auto pos_weight = torch::ones({64, 4});

    ASSERT_TRUE(torch::allclose(
        F::binary_cross_entropy_with_logits(output, target),
        F::binary_cross_entropy_with_logits(
            output,
            target,
            F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(
                pos_weight))));
  }
  { // test BCE with logits broadcasts pos weights
    const auto target = torch::rand({64, 4});
    const auto output = torch::rand({64, 4}) - 0.5;
    const auto pos_weight = torch::rand(4);
    const auto out1 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));

    // Each call below must use its own broadcast view of `pos_weight`.
    // Previously both passed the unexpanded tensor, so the (1, 4) and
    // (64, 4) broadcast paths were never exercised.
    const auto pos_weight1 = pos_weight.expand({1, 4});
    const auto out2 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight1));

    const auto pos_weight2 = pos_weight.expand({64, 4});
    const auto out3 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight2));

    ASSERT_TRUE(torch::allclose(out1, out2));
    ASSERT_TRUE(torch::allclose(out1, out3));
  }
  { // test BCE with logits with pos weight has correct grad at zero
    const auto output = torch::zeros({3, 1}, torch::requires_grad());
    const auto target = torch::zeros({3, 1});
    const auto pos_weight = torch::ones({3, 1});
    F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions()
            .pos_weight(pos_weight)
            .reduction(torch::kSum))
        .backward();
    const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    const auto grad = output.grad();
    ASSERT_TRUE(torch::allclose(grad, expected_grad));
  }
  { // test BCE with logits stability
    const auto output = torch::tensor({0., -120.});
    const auto target = torch::tensor({0., 1.});
    const auto pos_weight = torch::tensor({1., 1.});

    const auto out1 = F::binary_cross_entropy_with_logits(output, target);
    ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());

    const auto out2 = F::binary_cross_entropy_with_logits(
        output,
        target,
        F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
    ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
  }
}
3294