xref: /aosp_15_r20/external/pytorch/test/cpp/api/functional.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace F = torch::nn::functional;
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker struct FunctionalTest : torch::test::SeedingFixture {};
13*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv1d on deterministic arange data: a {2, 3, 5} input convolved
// with a {2, 3, 3} weight must match the precomputed {2, 2, 3} result, both
// when stride(1) is passed explicitly and when no options are given.
TEST_F(FunctionalTest, Conv1d) {
  auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({2, 3, 5});
  auto weight =
      torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3});
  auto y = F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{312., 348., 384.}, {798., 915., 1032.}},

       {{852., 888., 924.}, {2553., 2670., 2787.}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // The option-less overload is expected to produce the same output.
  auto y_no_options = F::conv1d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
31*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv2d with square spatial dimensions: a {1, 3, 5, 5} arange
// input and a {2, 3, 3, 3} arange weight must match the precomputed
// {1, 2, 3, 3} result, with explicit stride(1) and with no options.
TEST_F(FunctionalTest, Conv2dEven) {
  auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5});
  auto weight =
      torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},

        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // The option-less overload is expected to produce the same output.
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
53*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv2d with non-square spatial dimensions: a {1, 3, 5, 4} arange
// input and a {2, 3, 3, 2} arange weight must match the precomputed
// {1, 2, 3, 3} result, with explicit stride(1) and with no options.
TEST_F(FunctionalTest, Conv2dUneven) {
  auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 4});
  auto weight =
      torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 2});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},

        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // The option-less overload is expected to produce the same output.
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
73*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv3d: a {1, 3, 5, 5, 5} arange input and a {2, 3, 3, 3, 3}
// arange weight must match the precomputed {1, 2, 3, 3, 3} result, with
// explicit stride(1) and with no options.
TEST_F(FunctionalTest, Conv3d) {
  auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5, 5});
  auto weight =
      torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3, 3});
  auto y = F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
  auto expected = torch::tensor(
      {{{{{700704., 703944., 707184.},
          {716904., 720144., 723384.},
          {733104., 736344., 739584.}},

         {{781704., 784944., 788184.},
          {797904., 801144., 804384.},
          {814104., 817344., 820584.}},

         {{862704., 865944., 869184.},
          {878904., 882144., 885384.},
          {895104., 898344., 901584.}}},

        {{{1724220., 1734021., 1743822.},
          {1773225., 1783026., 1792827.},
          {1822230., 1832031., 1841832.}},

         {{1969245., 1979046., 1988847.},
          {2018250., 2028051., 2037852.},
          {2067255., 2077056., 2086857.}},

         {{2214270., 2224071., 2233872.},
          {2263275., 2273076., 2282877.},
          {2312280., 2322081., 2331882.}}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // The option-less overload is expected to produce the same output.
  auto y_no_options = F::conv3d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
111*da0073e9SAndroid Build Coastguard Worker 
// Checks F::max_pool1d with options built from 3 and stride 2 on an all-ones
// {1, 1, 5} input: the result must stay 3-D, be all ones, and have shape
// {1, 1, 2}.
TEST_F(FunctionalTest, MaxPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}
120*da0073e9SAndroid Build Coastguard Worker 
// Checks F::max_pool2d with options built from 3 and stride 2 on an all-ones
// {2, 5, 5} input (no explicit batch dimension): the result must stay 3-D,
// be all ones, and have shape {2, 2, 2}.
TEST_F(FunctionalTest, MaxPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}
129*da0073e9SAndroid Build Coastguard Worker 
// Checks that autograd flows through F::max_pool2d: after summing the pooled
// output and calling backward(), the input's gradient must have the same
// shape as the input itself. Only the shape is verified, not the values.
TEST_F(FunctionalTest, MaxPool2dBackward) {
  auto input = torch::rand(
      {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2));
  auto s = output.sum();
  s.backward();
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
138*da0073e9SAndroid Build Coastguard Worker 
// Checks F::max_pool3d with options built from 3 and stride 2 on an all-ones
// {2, 5, 5, 5} input: the result must stay 4-D, be all ones, and have shape
// {2, 2, 2, 2}.
TEST_F(FunctionalTest, MaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
147*da0073e9SAndroid Build Coastguard Worker 
// Checks F::avg_pool1d with options built from 3 and stride 2 on an all-ones
// {1, 1, 5} input: averaging ones must give ones, and the result must be 3-D
// with shape {1, 1, 2}.
TEST_F(FunctionalTest, AvgPool1d) {
  auto x = torch::ones({1, 1, 5});
  auto y = F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
}
156*da0073e9SAndroid Build Coastguard Worker 
// Checks F::avg_pool2d with options built from 3 and stride 2 on an all-ones
// {2, 5, 5} input: averaging ones must give ones, and the result must be 3-D
// with shape {2, 2, 2}.
TEST_F(FunctionalTest, AvgPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
}
165*da0073e9SAndroid Build Coastguard Worker 
// Checks F::avg_pool3d with options built from 3 and stride 2 on an all-ones
// {2, 5, 5, 5} input: averaging ones must give ones, and the result must be
// 4-D with shape {2, 2, 2, 2}.
TEST_F(FunctionalTest, AvgPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
174*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,FractionalMaxPool2d)175*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, FractionalMaxPool2d) {
176*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({2, 5, 5});
177*da0073e9SAndroid Build Coastguard Worker   auto y = F::fractional_max_pool2d(
178*da0073e9SAndroid Build Coastguard Worker       x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
179*da0073e9SAndroid Build Coastguard Worker 
180*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.ndimension(), 3);
181*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
182*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
183*da0073e9SAndroid Build Coastguard Worker 
184*da0073e9SAndroid Build Coastguard Worker   auto y_with_indices = F::fractional_max_pool2d_with_indices(
185*da0073e9SAndroid Build Coastguard Worker       x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
186*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
187*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
188*da0073e9SAndroid Build Coastguard Worker       std::get<1>(y_with_indices),
189*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
190*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
191*da0073e9SAndroid Build Coastguard Worker       std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2}));
192*da0073e9SAndroid Build Coastguard Worker 
193*da0073e9SAndroid Build Coastguard Worker   auto x1 = torch::ones({2, 2, 5, 5});
194*da0073e9SAndroid Build Coastguard Worker   auto y1 = F::fractional_max_pool2d(
195*da0073e9SAndroid Build Coastguard Worker       x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
196*da0073e9SAndroid Build Coastguard Worker 
197*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y1.ndimension(), 4);
198*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(y1, torch::ones({2, 2, 2, 2})));
199*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y1.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
200*da0073e9SAndroid Build Coastguard Worker 
201*da0073e9SAndroid Build Coastguard Worker   auto y1_with_indices = F::fractional_max_pool2d_with_indices(
202*da0073e9SAndroid Build Coastguard Worker       x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
203*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::equal(y1, std::get<0>(y1_with_indices)));
204*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
205*da0073e9SAndroid Build Coastguard Worker       std::get<1>(y1_with_indices),
206*da0073e9SAndroid Build Coastguard Worker       torch::tensor(
207*da0073e9SAndroid Build Coastguard Worker           {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}},
208*da0073e9SAndroid Build Coastguard Worker            {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}})));
209*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
210*da0073e9SAndroid Build Coastguard Worker       std::get<1>(y1_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
211*da0073e9SAndroid Build Coastguard Worker }
212*da0073e9SAndroid Build Coastguard Worker 
// Checks F::fractional_max_pool3d and its *_with_indices variant, using
// options built from 3 with output_size(2), on an all-ones {2, 5, 5, 5}
// input. The pooled values must be all ones with shape {2, 2, 2, 2}, the
// values from the *_with_indices call must equal those of the plain call,
// and the returned indices must match the hard-coded tensor below.
TEST_F(FunctionalTest, FractionalMaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::fractional_max_pool3d(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  auto y_with_indices = F::fractional_max_pool3d_with_indices(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));
  // Element 0 of the tuple holds the pooled values, element 1 the indices.
  ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y_with_indices),
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
  ASSERT_EQ(
      std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
233*da0073e9SAndroid Build Coastguard Worker 
// Checks F::lp_pool1d with norm_type 2, kernel_size 3 and stride 2 on an
// all-ones {1, 1, 5} input. The expected value for each of the two outputs
// is (1^norm_type * kernel_size)^(1 / norm_type); the result must be 3-D
// with shape {1, 1, 2}.
TEST_F(FunctionalTest, LPPool1d) {
  int norm_type = 2;
  int stride = 2;
  int kernel_size = 3;

  auto x = torch::ones({1, 1, 5});
  auto y = F::lp_pool1d(
      x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
       kernel_size)
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
}
251*da0073e9SAndroid Build Coastguard Worker 
// Checks F::lp_pool2d with norm_type 2, a {2, 3} kernel and stride 2 on an
// all-ones {1, 1, 2, 5} input. The expected value for each of the two
// outputs is (1^norm_type * (kernel_size[0] * kernel_size[1]))^(1 /
// norm_type); the result must be 4-D with shape {1, 1, 1, 2}.
TEST_F(FunctionalTest, LPPool2d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  auto x = torch::ones({1, 1, 2, 5});
  auto y = F::lp_pool2d(
      x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
}
269*da0073e9SAndroid Build Coastguard Worker 
// Checks F::lp_pool3d with norm_type 2, a {1, 2, 3} kernel and stride 2 on
// an all-ones {1, 1, 1, 2, 5} input. The expected value for each of the two
// outputs is (1^norm_type * product(kernel_size))^(1 / norm_type); the
// result must be 5-D with shape {1, 1, 1, 1, 2}.
TEST_F(FunctionalTest, LPPool3d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({1, 2, 3});

  auto x = torch::ones({1, 1, 1, 2, 5});
  auto y = F::lp_pool3d(
      x, F::LPPool3dFuncOptions(norm_type, kernel_size).stride(stride));
  auto expected =
      (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1] * kernel_size[2]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 5);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
}
287*da0073e9SAndroid Build Coastguard Worker 
// Checks F::cosine_similarity along dim 1 for two {2, 3} float inputs against
// precomputed per-row values. The expected values are given to four decimal
// places, so the comparison passes 1e-04 as allclose's second argument
// instead of relying on its default.
TEST_F(FunctionalTest, CosineSimilarity) {
  auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  auto output = F::cosine_similarity(
      input1, input2, F::CosineSimilarityFuncOptions().dim(1));
  auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));
}
296*da0073e9SAndroid Build Coastguard Worker 
// Checks F::smooth_l1_loss called with no options: the scalar loss for the
// inputs below must match the precomputed value, and after backward() the
// input's gradient must have the same shape as the input.
TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::smooth_l1_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
308*da0073e9SAndroid Build Coastguard Worker 
// Checks the F::smooth_l1_loss overload that takes reduction and beta as
// positional arguments (kMean, 0.5): the scalar loss must match the
// precomputed value 1.67, and after backward() the input's gradient must
// have the same shape as the input.
TEST_F(FunctionalTest, SmoothL1LossBeta) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment)
      F::smooth_l1_loss(
          input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
  auto expected = torch::tensor(1.67, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
323*da0073e9SAndroid Build Coastguard Worker 
// Same scenario as the positional-argument beta test, but passes reduction
// (kMean) and beta (0.5) through F::SmoothL1LossFuncOptions. The scalar loss
// must match 1.67, and after backward() the input's gradient must have the
// same shape as the input.
TEST_F(FunctionalTest, SmoothL1LossBetaOptions) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      F::smooth_l1_loss(
          input,
          target,
          F::SmoothL1LossFuncOptions().reduction(torch::kMean).beta(0.5));
  auto expected = torch::tensor(1.67, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
340*da0073e9SAndroid Build Coastguard Worker 
// Checks F::smooth_l1_loss with reduction kNone: the result is a per-element
// tensor. For these inputs each expected entry equals
// 0.5 * (input - target)^2 (differences 0.1, 0.2, 0.3). After backward() on
// the summed output, the input's gradient must have the input's shape.
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(bugprone-argument-comment)
      F::smooth_l1_loss(input, target, /*reduction=*/torch::kNone);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
354*da0073e9SAndroid Build Coastguard Worker 
// Checks F::huber_loss called with no options: the scalar loss for the
// inputs below must match the precomputed value, and after backward() the
// input's gradient must have the same shape as the input.
TEST_F(FunctionalTest, HuberLossDefaultOptions) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output = F::huber_loss(input, target);
  auto expected = torch::tensor(0.0233335, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
366*da0073e9SAndroid Build Coastguard Worker 
// Checks F::huber_loss with reduction kMean and delta 0.5. The expected
// value is written as 1.67 * 0.5, i.e. the value 1.67 scaled by delta. After
// backward() the input's gradient must have the same shape as the input.
TEST_F(FunctionalTest, HuberLossDelta) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor(1.67 * 0.5, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
379*da0073e9SAndroid Build Coastguard Worker 
// Checks F::huber_loss with reduction kNone: the result is a per-element
// tensor. For these inputs each expected entry equals
// 0.5 * (input - target)^2 (differences 0.1, 0.2, 0.3). After backward() on
// the summed output, the input's gradient must have the input's shape.
TEST_F(FunctionalTest, HuberLossNoReduction) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto options = F::HuberLossFuncOptions().reduction(torch::kNone);
  auto output = F::huber_loss(input, target, options);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  // ASSERT_EQ (rather than ASSERT_TRUE on ==) reports both shapes on failure.
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
392*da0073e9SAndroid Build Coastguard Worker 
// F::soft_margin_loss called without options: the result is compared with a
// one-element reference tensor, and backward() must fill input.grad().
TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);
  const auto input = torch::tensor({2., 4., 1., 3.}, float_with_grad);
  const auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);

  const auto loss = F::soft_margin_loss(input, target);
  const auto reference = torch::tensor({1.3767317}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
405*da0073e9SAndroid Build Coastguard Worker 
// F::multilabel_soft_margin_loss called without options on a 2x4 input:
// compares against a one-element reference and checks the gradient shape.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);
  const auto input =
      torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, float_with_grad);
  const auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);

  const auto loss = F::multilabel_soft_margin_loss(input, target);
  const auto reference = torch::tensor({0.7608436}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
420*da0073e9SAndroid Build Coastguard Worker 
// F::soft_margin_loss with torch::kNone passed as the options argument: the
// output keeps one loss value per input element.
TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);
  const auto input = torch::tensor({2., 4., 1., 3.}, float_with_grad);
  const auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);

  const auto loss = F::soft_margin_loss(input, target, torch::kNone);
  const auto reference = torch::tensor(
      {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
  // Sum first: backward() is invoked on the scalar, not the per-element loss.
  loss.sum().backward();

  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
434*da0073e9SAndroid Build Coastguard Worker 
// F::multilabel_soft_margin_loss with a four-element weight tensor and no
// reduction: the 2x4 input yields a two-element loss.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
  const auto float_with_grad = torch::dtype(torch::kFloat).requires_grad(true);
  const auto input =
      torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, float_with_grad);
  const auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  const auto class_weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);

  const auto loss = F::multilabel_soft_margin_loss(
      input,
      target,
      F::MultilabelSoftMarginLossFuncOptions()
          .reduction(torch::kNone)
          .weight(class_weight));
  const auto reference = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
  loss.sum().backward();

  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
453*da0073e9SAndroid Build Coastguard Worker 
// F::pairwise_distance with p = 1: in each row the element-wise absolute
// differences add up to 6.
TEST_F(FunctionalTest, PairwiseDistance) {
  const auto lhs = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  const auto rhs = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  const auto p1 = F::PairwiseDistanceFuncOptions().p(1);

  const auto distance = F::pairwise_distance(lhs, rhs, p1);
  const auto reference = torch::tensor({6, 6}, torch::kFloat);

  ASSERT_TRUE(torch::allclose(distance, reference));
}
462*da0073e9SAndroid Build Coastguard Worker 
// F::pdist on small inputs: once with the default second argument and once
// with an explicit value of 1.5.
TEST_F(FunctionalTest, PDist) {
  {
    // Two rows -> a single pairwise distance.
    const auto points = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}});
    const auto reference = torch::tensor({11.7898});
    const auto distances = F::pdist(points);
    ASSERT_TRUE(torch::allclose(distances, reference));
  }
  {
    // Three rows -> three pairwise distances.
    const auto points = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}});
    const auto reference = torch::tensor({4.0, 4.8945, 2.0});
    const auto distances = F::pdist(points, 1.5);
    ASSERT_TRUE(torch::allclose(distances, reference));
  }
}
477*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 1-D max pooling of an all-ones {1, 1, 5} tensor to output size 3
// must give an all-ones {1, 1, 3} tensor.
TEST_F(FunctionalTest, AdaptiveMaxPool1d) {
  const std::vector<int64_t> pooled_shape{1, 1, 3};
  const auto ones_in = torch::ones({1, 1, 5});

  const auto pooled =
      F::adaptive_max_pool1d(ones_in, F::AdaptiveMaxPool1dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
486*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 2-D max pooling of an all-ones {2, 5, 5} tensor to output size 3
// must give an all-ones {2, 3, 3} tensor.
TEST_F(FunctionalTest, AdaptiveMaxPool2d) {
  const std::vector<int64_t> pooled_shape{2, 3, 3};
  const auto ones_in = torch::ones({2, 5, 5});

  const auto pooled =
      F::adaptive_max_pool2d(ones_in, F::AdaptiveMaxPool2dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
495*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 3-D max pooling of an all-ones {2, 5, 5, 5} tensor to output
// size 3 must give an all-ones {2, 3, 3, 3} tensor.
TEST_F(FunctionalTest, AdaptiveMaxPool3d) {
  const std::vector<int64_t> pooled_shape{2, 3, 3, 3};
  const auto ones_in = torch::ones({2, 5, 5, 5});

  const auto pooled =
      F::adaptive_max_pool3d(ones_in, F::AdaptiveMaxPool3dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
504*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 1-D average pooling of an all-ones {1, 1, 5} tensor to output
// size 3 must give an all-ones {1, 1, 3} tensor.
TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
  const std::vector<int64_t> pooled_shape{1, 1, 3};
  const auto ones_in = torch::ones({1, 1, 5});

  const auto pooled =
      F::adaptive_avg_pool1d(ones_in, F::AdaptiveAvgPool1dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
513*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 2-D average pooling of an all-ones {2, 5, 5} tensor to output
// size 3 must give an all-ones {2, 3, 3} tensor.
TEST_F(FunctionalTest, AdaptiveAvgPool2d) {
  const std::vector<int64_t> pooled_shape{2, 3, 3};
  const auto ones_in = torch::ones({2, 5, 5});

  const auto pooled =
      F::adaptive_avg_pool2d(ones_in, F::AdaptiveAvgPool2dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
522*da0073e9SAndroid Build Coastguard Worker 
// Adaptive 3-D average pooling of an all-ones {2, 5, 5, 5} tensor to output
// size 3 must give an all-ones {2, 3, 3, 3} tensor.
TEST_F(FunctionalTest, AdaptiveAvgPool3d) {
  const std::vector<int64_t> pooled_shape{2, 3, 3, 3};
  const auto ones_in = torch::ones({2, 5, 5, 5});

  const auto pooled =
      F::adaptive_avg_pool3d(ones_in, F::AdaptiveAvgPool3dFuncOptions(3));

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(pooled_shape)));
  ASSERT_EQ(pooled.sizes(), pooled_shape);
}
531*da0073e9SAndroid Build Coastguard Worker 
// Shape-only check for F::l1_loss: the loss has empty sizes (a scalar) and
// backward() yields a gradient shaped like the {5, 6} input.
TEST_F(FunctionalTest, L1Loss) {
  // Keep this creation order: both tensors draw from the seeded generator.
  const auto logits = torch::randn({5, 6}, torch::requires_grad());
  const auto labels = torch::empty({5, 6}).random_(2);

  const auto loss = F::l1_loss(logits.sigmoid(), labels);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
542*da0073e9SAndroid Build Coastguard Worker 
// Shape-only check for F::mse_loss: the loss has empty sizes (a scalar) and
// backward() yields a gradient shaped like the {5, 6} input.
TEST_F(FunctionalTest, MSELoss) {
  // Keep this creation order: both tensors draw from the seeded generator.
  const auto logits = torch::randn({5, 6}, torch::requires_grad());
  const auto labels = torch::empty({5, 6}).random_(2);

  const auto loss = F::mse_loss(logits.sigmoid(), labels);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
553*da0073e9SAndroid Build Coastguard Worker 
// Shape-only check for F::binary_cross_entropy: the loss has empty sizes (a
// scalar) and backward() yields a gradient shaped like the {5, 6} input.
TEST_F(FunctionalTest, BCELoss) {
  // Keep this creation order: both tensors draw from the seeded generator.
  const auto logits = torch::randn({5, 6}, torch::requires_grad());
  const auto labels = torch::empty({5, 6}).random_(2);

  const auto loss = F::binary_cross_entropy(logits.sigmoid(), labels);
  loss.sum().backward();

  ASSERT_EQ(loss.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
564*da0073e9SAndroid Build Coastguard Worker 
// Shape-only check for F::kl_div: the loss has empty sizes (a scalar) and
// backward() yields a gradient shaped like the {5, 6} input.
//
// The unused `KLDivLoss loss;` module instance was dropped; this test only
// exercises the functional entry point.
TEST_F(FunctionalTest, KLDivLoss) {
  auto input = torch::randn({5, 6}, torch::requires_grad());
  auto target = torch::empty({5, 6}).random_(2);
  auto output = F::kl_div(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
576*da0073e9SAndroid Build Coastguard Worker 
// F::hinge_embedding_loss with margin 2 on a 2x3 input, compared against a
// one-element reference tensor.
TEST_F(FunctionalTest, HingeEmbeddingLoss) {
  const auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat);
  const auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
  const auto margin_two = F::HingeEmbeddingLossFuncOptions().margin(2);

  const auto loss = F::hinge_embedding_loss(input, target, margin_two);
  const auto reference = torch::tensor({10}, torch::kFloat);

  ASSERT_TRUE(torch::allclose(loss, reference));
}
586*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,GridSample)587*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, GridSample) {
588*da0073e9SAndroid Build Coastguard Worker   auto input =
589*da0073e9SAndroid Build Coastguard Worker       torch::arange(9, torch::kFloat).view(std::vector<int64_t>({1, 1, 3, 3}));
590*da0073e9SAndroid Build Coastguard Worker   auto grid = torch::tensor(
591*da0073e9SAndroid Build Coastguard Worker       {{{{-2., -1.}, {-1., -1.}, {0., -1.}},
592*da0073e9SAndroid Build Coastguard Worker         {{-1., 0.}, {0., 0.}, {1., 0.}},
593*da0073e9SAndroid Build Coastguard Worker         {{0., 1.}, {1., 1.}, {2., 1.}}}},
594*da0073e9SAndroid Build Coastguard Worker       torch::kFloat);
595*da0073e9SAndroid Build Coastguard Worker 
596*da0073e9SAndroid Build Coastguard Worker   // bilinear, zeros, true
597*da0073e9SAndroid Build Coastguard Worker   auto options = F::GridSampleFuncOptions()
598*da0073e9SAndroid Build Coastguard Worker                      .mode(torch::kBilinear)
599*da0073e9SAndroid Build Coastguard Worker                      .padding_mode(torch::kZeros)
600*da0073e9SAndroid Build Coastguard Worker                      .align_corners(true);
601*da0073e9SAndroid Build Coastguard Worker   auto output = F::grid_sample(input, grid, options);
602*da0073e9SAndroid Build Coastguard Worker   auto expected = torch::tensor(
603*da0073e9SAndroid Build Coastguard Worker       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
604*da0073e9SAndroid Build Coastguard Worker 
605*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
606*da0073e9SAndroid Build Coastguard Worker 
607*da0073e9SAndroid Build Coastguard Worker   // bilinear, zeros, false
608*da0073e9SAndroid Build Coastguard Worker   options = F::GridSampleFuncOptions()
609*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBilinear)
610*da0073e9SAndroid Build Coastguard Worker                 .padding_mode(torch::kZeros)
611*da0073e9SAndroid Build Coastguard Worker                 .align_corners(false);
612*da0073e9SAndroid Build Coastguard Worker   output = F::grid_sample(input, grid, options);
613*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
614*da0073e9SAndroid Build Coastguard Worker       {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat);
615*da0073e9SAndroid Build Coastguard Worker 
616*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
617*da0073e9SAndroid Build Coastguard Worker 
618*da0073e9SAndroid Build Coastguard Worker   // default options (bilinear, zeros, false) same result as above
619*da0073e9SAndroid Build Coastguard Worker   output = F::grid_sample(input, grid);
620*da0073e9SAndroid Build Coastguard Worker 
621*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
622*da0073e9SAndroid Build Coastguard Worker 
623*da0073e9SAndroid Build Coastguard Worker   // nearest, zeros, true
624*da0073e9SAndroid Build Coastguard Worker   options = F::GridSampleFuncOptions()
625*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kNearest)
626*da0073e9SAndroid Build Coastguard Worker                 .padding_mode(torch::kZeros)
627*da0073e9SAndroid Build Coastguard Worker                 .align_corners(true);
628*da0073e9SAndroid Build Coastguard Worker   output = F::grid_sample(input, grid, options);
629*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
630*da0073e9SAndroid Build Coastguard Worker       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
631*da0073e9SAndroid Build Coastguard Worker 
632*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
633*da0073e9SAndroid Build Coastguard Worker 
634*da0073e9SAndroid Build Coastguard Worker   // bilinear, border, true
635*da0073e9SAndroid Build Coastguard Worker   options = F::GridSampleFuncOptions()
636*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBilinear)
637*da0073e9SAndroid Build Coastguard Worker                 .padding_mode(torch::kBorder)
638*da0073e9SAndroid Build Coastguard Worker                 .align_corners(true);
639*da0073e9SAndroid Build Coastguard Worker   output = F::grid_sample(input, grid, options);
640*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
641*da0073e9SAndroid Build Coastguard Worker       {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat);
642*da0073e9SAndroid Build Coastguard Worker 
643*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
644*da0073e9SAndroid Build Coastguard Worker 
645*da0073e9SAndroid Build Coastguard Worker   // bilinear, reflection, true
646*da0073e9SAndroid Build Coastguard Worker   options = F::GridSampleFuncOptions()
647*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBilinear)
648*da0073e9SAndroid Build Coastguard Worker                 .padding_mode(torch::kReflection)
649*da0073e9SAndroid Build Coastguard Worker                 .align_corners(true);
650*da0073e9SAndroid Build Coastguard Worker   output = F::grid_sample(input, grid, options);
651*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
652*da0073e9SAndroid Build Coastguard Worker       {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat);
653*da0073e9SAndroid Build Coastguard Worker 
654*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
655*da0073e9SAndroid Build Coastguard Worker }
656*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,AffineGrid)657*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, AffineGrid) {
658*da0073e9SAndroid Build Coastguard Worker   {
659*da0073e9SAndroid Build Coastguard Worker     // 2D affine.
660*da0073e9SAndroid Build Coastguard Worker     auto theta = torch::arange(1., 13).view(std::vector<int64_t>({2, 2, 3}));
661*da0073e9SAndroid Build Coastguard Worker     auto size = std::vector<int64_t>({2, 3, 2, 2});
662*da0073e9SAndroid Build Coastguard Worker     auto align_corners = true;
663*da0073e9SAndroid Build Coastguard Worker     auto output = F::affine_grid(theta, size, !align_corners);
664*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
665*da0073e9SAndroid Build Coastguard Worker         {{{{1.50, 1.50}, {2.50, 5.50}}, {{3.50, 6.50}, {4.50, 10.50}}},
666*da0073e9SAndroid Build Coastguard Worker          {{{1.50, 1.50}, {8.50, 11.50}}, {{9.50, 12.50}, {16.50, 22.50}}}});
667*da0073e9SAndroid Build Coastguard Worker     auto output_aligned = F::affine_grid(theta, size, align_corners);
668*da0073e9SAndroid Build Coastguard Worker     auto expected_aligned = torch::tensor(
669*da0073e9SAndroid Build Coastguard Worker         {{{{0.0, -3.0}, {2.0, 5.0}}, {{4.0, 7.0}, {6.0, 15.0}}},
670*da0073e9SAndroid Build Coastguard Worker          {{{-6.0, -9.0}, {8.0, 11.0}}, {{10.0, 13.0}, {24.0, 33.0}}}});
671*da0073e9SAndroid Build Coastguard Worker 
672*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
673*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output_aligned.allclose(expected_aligned));
674*da0073e9SAndroid Build Coastguard Worker   }
675*da0073e9SAndroid Build Coastguard Worker   {
676*da0073e9SAndroid Build Coastguard Worker     // 3D affine.
677*da0073e9SAndroid Build Coastguard Worker     auto theta = torch::arange(1., 13).view(std::vector<int64_t>({1, 3, 4}));
678*da0073e9SAndroid Build Coastguard Worker     auto size = std::vector<int64_t>({1, 1, 3, 2, 2});
679*da0073e9SAndroid Build Coastguard Worker     auto align_corners = true;
680*da0073e9SAndroid Build Coastguard Worker     auto output = F::affine_grid(theta, size, !align_corners);
681*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
682*da0073e9SAndroid Build Coastguard Worker         {{{{{0.5000, -2.1667, -4.8333}, {1.5000, 2.8333, 4.1667}},
683*da0073e9SAndroid Build Coastguard Worker            {{2.5000, 3.8333, 5.1667}, {3.5000, 8.8333, 14.1667}}},
684*da0073e9SAndroid Build Coastguard Worker           {{{2.5000, 2.5000, 2.5000}, {3.5000, 7.5000, 11.5000}},
685*da0073e9SAndroid Build Coastguard Worker            {{4.5000, 8.5000, 12.5000}, {5.5000, 13.5000, 21.5000}}},
686*da0073e9SAndroid Build Coastguard Worker           {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}},
687*da0073e9SAndroid Build Coastguard Worker            {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}});
688*da0073e9SAndroid Build Coastguard Worker     auto output_aligned = F::affine_grid(theta, size, align_corners);
689*da0073e9SAndroid Build Coastguard Worker     auto expected_aligned = torch::tensor(
690*da0073e9SAndroid Build Coastguard Worker         {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}},
691*da0073e9SAndroid Build Coastguard Worker            {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}},
692*da0073e9SAndroid Build Coastguard Worker           {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}},
693*da0073e9SAndroid Build Coastguard Worker            {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}},
694*da0073e9SAndroid Build Coastguard Worker           {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}},
695*da0073e9SAndroid Build Coastguard Worker            {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}});
696*da0073e9SAndroid Build Coastguard Worker 
697*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected, 1e-2));
698*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output_aligned.allclose(expected_aligned));
699*da0073e9SAndroid Build Coastguard Worker   }
700*da0073e9SAndroid Build Coastguard Worker   {
701*da0073e9SAndroid Build Coastguard Worker     auto theta = torch::empty({1, 2, 3}, torch::kDouble);
702*da0073e9SAndroid Build Coastguard Worker     auto size = std::vector<int64_t>({1, 1, 2, 2});
703*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
704*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(torch::empty({2, 2, 3}), {-1, 1, 2, 2}),
705*da0073e9SAndroid Build Coastguard Worker         "Expected non-zero, positive output size. Got [-1, 1, 2, 2]");
706*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
707*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(torch::empty({2, 2, 3}, torch::kInt), size),
708*da0073e9SAndroid Build Coastguard Worker         "Expected theta to have floating point type, but got int");
709*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
710*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta[0], size),
711*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
712*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2]. Got [2, 3].");
713*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
714*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.unsqueeze(0), size),
715*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
716*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2]. Got [1, 1, 2, 3].");
717*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
718*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.repeat({1, 2, 1}), size),
719*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
720*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2]. Got [1, 4, 3].");
721*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
722*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.repeat({1, 1, 2}), size),
723*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
724*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2]. Got [1, 2, 6].");
725*da0073e9SAndroid Build Coastguard Worker   }
726*da0073e9SAndroid Build Coastguard Worker   {
727*da0073e9SAndroid Build Coastguard Worker     auto theta = torch::empty({1, 3, 4}, torch::kDouble);
728*da0073e9SAndroid Build Coastguard Worker     auto size = std::vector<int64_t>({1, 1, 2, 2, 3});
729*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
730*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta[0], size),
731*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
732*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2, 3]. Got [3, 4].");
733*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
734*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.unsqueeze(0), size),
735*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
736*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2, 3]. Got [1, 1, 3, 4].");
737*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
738*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.repeat({1, 2, 1}), size),
739*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
740*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2, 3]. Got [1, 6, 4].");
741*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
742*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta.repeat({1, 1, 2}), size),
743*da0073e9SAndroid Build Coastguard Worker         "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
744*da0073e9SAndroid Build Coastguard Worker         "[1, 1, 2, 2, 3]. Got [1, 3, 8].");
745*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
746*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta, {1, 1, 1, 2, 2, 3}),
747*da0073e9SAndroid Build Coastguard Worker         "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
748*da0073e9SAndroid Build Coastguard Worker         "transforms, respectively. Got size [1, 1, 1, 2, 2, 3]");
749*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
750*da0073e9SAndroid Build Coastguard Worker         F::affine_grid(theta, {1, 1}),
751*da0073e9SAndroid Build Coastguard Worker         "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
752*da0073e9SAndroid Build Coastguard Worker         "transforms, respectively. Got size [1, 1]");
753*da0073e9SAndroid Build Coastguard Worker   }
754*da0073e9SAndroid Build Coastguard Worker }
755*da0073e9SAndroid Build Coastguard Worker 
// F::multi_margin_loss with margin 2 and a three-element weight tensor on a
// 3x3 input; compared with the reference at relative tolerance 1e-4.
TEST_F(FunctionalTest, MultiMarginLoss) {
  const auto class_weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
  const auto input = torch::tensor(
      {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto target = torch::tensor({2, 1, 0}, torch::kLong);
  const auto opts =
      F::MultiMarginLossFuncOptions().margin(2).weight(class_weight);

  const auto loss = F::multi_margin_loss(input, target, opts);
  const auto reference = torch::tensor({0.305556}, torch::kFloat);

  ASSERT_TRUE(loss.allclose(reference, 1e-04));
}
768*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, CosineEmbeddingLoss) {
  // Two input pairs with targets 1 and -1, margin 0.5; the result is compared
  // against a precomputed reference value.
  const auto lhs = torch::tensor({{2, 3, 4}, {6, 2, 4}});
  const auto rhs = torch::tensor({{2, 3, 5}, {9, 12, 0}});
  const auto labels = torch::tensor({1, -1});

  const auto options = F::CosineEmbeddingLossFuncOptions().margin(0.5);
  const auto loss = F::cosine_embedding_loss(lhs, rhs, labels, options);

  const auto reference = torch::tensor({0.1004}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(loss, reference, /*rtol=*/1e-4));
}
779*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) {
  // Calls F::multilabel_margin_loss without explicit options, checks the
  // value, and checks that backward() yields a gradient shaped like the input.
  const auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
  const auto loss = F::multilabel_margin_loss(scores, labels);

  // Reduce to a scalar before differentiating.
  const auto total = loss.sum();
  total.backward();

  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
792*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) {
  // Same data as the default-options test, but with reduction torch::kNone.
  // Checks the value and that backward() yields a gradient shaped like the
  // input.
  const auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
  const auto loss = F::multilabel_margin_loss(scores, labels, torch::kNone);

  // Reduce to a scalar before differentiating.
  const auto total = loss.sum();
  total.backward();

  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(loss, reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
805*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, TripletMarginLoss) {
  // A single triplet with margin 1.0; the reference loss is zero.
  const auto anchor_pt = torch::tensor({{3., 3.}}, torch::kFloat);
  const auto positive_pt = torch::tensor({{2., 2.}}, torch::kFloat);
  const auto negative_pt = torch::tensor({{0., 0.}}, torch::kFloat);

  const auto options = F::TripletMarginLossFuncOptions().margin(1.0);
  const auto loss =
      F::triplet_margin_loss(anchor_pt, positive_pt, negative_pt, options);

  const auto reference = torch::tensor({0.}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(loss, reference, /*rtol=*/1e-04));
}
819*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,TripletMarginWithDistanceLossDefaultParity)820*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
821*da0073e9SAndroid Build Coastguard Worker   // Check that if we use torch::pairwise_distance with the default
822*da0073e9SAndroid Build Coastguard Worker   // TripletMarginLoss options as our distance function, the outputs
823*da0073e9SAndroid Build Coastguard Worker   // are equal (i.e., equal under defaults).
824*da0073e9SAndroid Build Coastguard Worker 
825*da0073e9SAndroid Build Coastguard Worker   std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
826*da0073e9SAndroid Build Coastguard Worker       torch::kSum, torch::kMean, torch::kNone};
827*da0073e9SAndroid Build Coastguard Worker   std::vector<float> margins = {0.5, 1.0, 1.5};
828*da0073e9SAndroid Build Coastguard Worker   std::vector<bool> swaps = {true, false};
829*da0073e9SAndroid Build Coastguard Worker 
830*da0073e9SAndroid Build Coastguard Worker   for (auto& reduction : reductions) {
831*da0073e9SAndroid Build Coastguard Worker     for (auto& margin : margins) {
832*da0073e9SAndroid Build Coastguard Worker       for (const auto& swap : swaps) {
833*da0073e9SAndroid Build Coastguard Worker         auto anchor = torch::randn(
834*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
835*da0073e9SAndroid Build Coastguard Worker         auto positive = torch::randn(
836*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
837*da0073e9SAndroid Build Coastguard Worker         auto negative = torch::randn(
838*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
839*da0073e9SAndroid Build Coastguard Worker 
840*da0073e9SAndroid Build Coastguard Worker         auto basicOptions = F::TripletMarginLossFuncOptions()
841*da0073e9SAndroid Build Coastguard Worker                                 .reduction(reduction)
842*da0073e9SAndroid Build Coastguard Worker                                 .margin(margin)
843*da0073e9SAndroid Build Coastguard Worker                                 .swap(swap);
844*da0073e9SAndroid Build Coastguard Worker         auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions()
845*da0073e9SAndroid Build Coastguard Worker                                    .reduction(reduction)
846*da0073e9SAndroid Build Coastguard Worker                                    .margin(margin)
847*da0073e9SAndroid Build Coastguard Worker                                    .swap(swap);
848*da0073e9SAndroid Build Coastguard Worker         TripletMarginLoss basicLoss(basicOptions);
849*da0073e9SAndroid Build Coastguard Worker         TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
850*da0073e9SAndroid Build Coastguard Worker 
851*da0073e9SAndroid Build Coastguard Worker         auto basicOutput =
852*da0073e9SAndroid Build Coastguard Worker             F::triplet_margin_loss(anchor, positive, negative, basicOptions);
853*da0073e9SAndroid Build Coastguard Worker         auto distanceOutput = F::triplet_margin_with_distance_loss(
854*da0073e9SAndroid Build Coastguard Worker             anchor, positive, negative, distanceOptions);
855*da0073e9SAndroid Build Coastguard Worker 
856*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
857*da0073e9SAndroid Build Coastguard Worker 
858*da0073e9SAndroid Build Coastguard Worker         // handle for torch::kNone reduction
859*da0073e9SAndroid Build Coastguard Worker         auto sum = distanceOutput.sum();
860*da0073e9SAndroid Build Coastguard Worker         sum.backward();
861*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
862*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(positive.sizes(), positive.grad().sizes());
863*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(negative.sizes(), negative.grad().sizes());
864*da0073e9SAndroid Build Coastguard Worker       }
865*da0073e9SAndroid Build Coastguard Worker     }
866*da0073e9SAndroid Build Coastguard Worker   }
867*da0073e9SAndroid Build Coastguard Worker }
868*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, NLLLoss) {
  // Checks F::nll_loss against a precomputed reference, once with
  // ignore_index(-100) and mean reduction spelled out and once with no
  // options passed; both calls must match the same reference.
  const auto scores = torch::tensor(
      {{-0.1315, -3.1315, -2.5315},
       {-3.7038, -0.1038, -2.6038},
       {-2.3422, -1.3422, -0.4422}},
      torch::kFloat);
  const auto labels = torch::tensor({1, 0, 2}, torch::kLong);
  const auto reference = torch::tensor(2.4258, torch::kFloat);

  const auto explicit_options =
      F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean);
  const auto explicit_loss = F::nll_loss(scores, labels, explicit_options);
  ASSERT_TRUE(torch::allclose(explicit_loss, reference, /*rtol=*/1e-04));

  ASSERT_TRUE(
      torch::allclose(F::nll_loss(scores, labels), reference, /*rtol=*/1e-04));
}
884*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, CrossEntropy) {
  // Three cases, each compared against a precomputed reference value:
  // explicit vs. omitted options, label smoothing with class-index targets,
  // and label smoothing with probability targets.
  const auto class_indices = torch::tensor({0, 1}, torch::kLong);

  {
    const auto scores = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat);
    const auto reference = torch::tensor(0.6931, torch::kFloat);

    const auto explicit_options =
        F::CrossEntropyFuncOptions().ignore_index(-100).reduction(
            torch::kMean);
    const auto explicit_loss =
        F::cross_entropy(scores, class_indices, explicit_options);
    ASSERT_TRUE(torch::allclose(explicit_loss, reference, /*rtol=*/1e-04));

    // Omitting the options must give the same result.
    ASSERT_TRUE(torch::allclose(
        F::cross_entropy(scores, class_indices), reference, /*rtol=*/1e-04));
  }

  // The two label-smoothing cases share the same input scores.
  const auto smoothing_scores =
      torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);

  {
    // label smoothing with class indices
    const auto options =
        F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(
            torch::kMean);
    const auto loss =
        F::cross_entropy(smoothing_scores, class_indices, options);
    const auto reference = torch::tensor(0.3326, torch::kFloat);
    ASSERT_TRUE(torch::allclose(loss, reference, /*rtol=*/1e-04));
  }

  {
    // label smoothing with target probabilities
    const auto class_probs =
        torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
    const auto options =
        F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(
            torch::kMean);
    const auto loss = F::cross_entropy(smoothing_scores, class_probs, options);
    const auto reference = torch::tensor(0.5701, torch::kFloat);
    ASSERT_TRUE(torch::allclose(loss, reference, /*rtol=*/1e-04));
  }
}
917*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,MaxUnpool1d)918*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, MaxUnpool1d) {
919*da0073e9SAndroid Build Coastguard Worker   auto x = torch::tensor(
920*da0073e9SAndroid Build Coastguard Worker       {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
921*da0073e9SAndroid Build Coastguard Worker   auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
922*da0073e9SAndroid Build Coastguard Worker   auto y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));
923*da0073e9SAndroid Build Coastguard Worker 
924*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.ndimension(), 3);
925*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
926*da0073e9SAndroid Build Coastguard Worker       y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
927*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
928*da0073e9SAndroid Build Coastguard Worker 
929*da0073e9SAndroid Build Coastguard Worker   x = torch::tensor(
930*da0073e9SAndroid Build Coastguard Worker       {{2, 4, 5}}, torch::dtype(torch::kFloat).requires_grad(true));
931*da0073e9SAndroid Build Coastguard Worker   indices = torch::tensor({{1, 3, 4}}, torch::kLong);
932*da0073e9SAndroid Build Coastguard Worker   y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));
933*da0073e9SAndroid Build Coastguard Worker 
934*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.ndimension(), 2);
935*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
936*da0073e9SAndroid Build Coastguard Worker       y, torch::tensor({{0, 2, 0, 4, 5, 0, 0, 0, 0}}, torch::kFloat)));
937*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 9}));
938*da0073e9SAndroid Build Coastguard Worker 
939*da0073e9SAndroid Build Coastguard Worker   x = torch::tensor(
940*da0073e9SAndroid Build Coastguard Worker       {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
941*da0073e9SAndroid Build Coastguard Worker   indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
942*da0073e9SAndroid Build Coastguard Worker   y = F::max_unpool1d(
943*da0073e9SAndroid Build Coastguard Worker       x,
944*da0073e9SAndroid Build Coastguard Worker       indices,
945*da0073e9SAndroid Build Coastguard Worker       F::MaxUnpool1dFuncOptions(3).output_size(
946*da0073e9SAndroid Build Coastguard Worker           std::vector<int64_t>({1, 1, 9})));
947*da0073e9SAndroid Build Coastguard Worker 
948*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.ndimension(), 3);
949*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
950*da0073e9SAndroid Build Coastguard Worker       y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
951*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
952*da0073e9SAndroid Build Coastguard Worker 
953*da0073e9SAndroid Build Coastguard Worker   x = torch::tensor(
954*da0073e9SAndroid Build Coastguard Worker       {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
955*da0073e9SAndroid Build Coastguard Worker   indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
956*da0073e9SAndroid Build Coastguard Worker   y = F::max_unpool1d(
957*da0073e9SAndroid Build Coastguard Worker       x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1));
958*da0073e9SAndroid Build Coastguard Worker 
959*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.ndimension(), 3);
960*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
961*da0073e9SAndroid Build Coastguard Worker       torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
962*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
963*da0073e9SAndroid Build Coastguard Worker }
964*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,MaxUnpool2d)965*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, MaxUnpool2d) {
966*da0073e9SAndroid Build Coastguard Worker   auto indices = torch::tensor(
967*da0073e9SAndroid Build Coastguard Worker       {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
968*da0073e9SAndroid Build Coastguard Worker        {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
969*da0073e9SAndroid Build Coastguard Worker       torch::kLong);
970*da0073e9SAndroid Build Coastguard Worker   auto x = torch::tensor(
971*da0073e9SAndroid Build Coastguard Worker       {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
972*da0073e9SAndroid Build Coastguard Worker        {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
973*da0073e9SAndroid Build Coastguard Worker       torch::dtype(torch::kFloat).requires_grad(true));
974*da0073e9SAndroid Build Coastguard Worker   auto y = F::max_unpool2d(
975*da0073e9SAndroid Build Coastguard Worker       x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));
976*da0073e9SAndroid Build Coastguard Worker 
977*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.dim(), 4);
978*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
979*da0073e9SAndroid Build Coastguard Worker       y,
980*da0073e9SAndroid Build Coastguard Worker       torch::tensor(
981*da0073e9SAndroid Build Coastguard Worker           {{{{0, 0, 0, 0, 0},
982*da0073e9SAndroid Build Coastguard Worker              {0, 6, 0, 8, 9},
983*da0073e9SAndroid Build Coastguard Worker              {0, 0, 0, 0, 0},
984*da0073e9SAndroid Build Coastguard Worker              {0, 16, 0, 18, 19},
985*da0073e9SAndroid Build Coastguard Worker              {0, 21, 0, 23, 24}}},
986*da0073e9SAndroid Build Coastguard Worker            {{{0, 0, 0, 0, 0},
987*da0073e9SAndroid Build Coastguard Worker              {0, 31, 0, 33, 34},
988*da0073e9SAndroid Build Coastguard Worker              {0, 0, 0, 0, 0},
989*da0073e9SAndroid Build Coastguard Worker              {0, 41, 0, 43, 44},
990*da0073e9SAndroid Build Coastguard Worker              {0, 46, 0, 48, 49}}}},
991*da0073e9SAndroid Build Coastguard Worker           torch::kFloat)));
992*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
993*da0073e9SAndroid Build Coastguard Worker 
994*da0073e9SAndroid Build Coastguard Worker   indices = torch::tensor(
995*da0073e9SAndroid Build Coastguard Worker       {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
996*da0073e9SAndroid Build Coastguard Worker        {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
997*da0073e9SAndroid Build Coastguard Worker       torch::kLong);
998*da0073e9SAndroid Build Coastguard Worker   x = torch::tensor(
999*da0073e9SAndroid Build Coastguard Worker       {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
1000*da0073e9SAndroid Build Coastguard Worker        {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}},
1001*da0073e9SAndroid Build Coastguard Worker       torch::dtype(torch::kFloat).requires_grad(true));
1002*da0073e9SAndroid Build Coastguard Worker   y = F::max_unpool2d(
1003*da0073e9SAndroid Build Coastguard Worker       x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));
1004*da0073e9SAndroid Build Coastguard Worker 
1005*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.dim(), 3);
1006*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(
1007*da0073e9SAndroid Build Coastguard Worker       y,
1008*da0073e9SAndroid Build Coastguard Worker       torch::tensor(
1009*da0073e9SAndroid Build Coastguard Worker           {{{0, 0, 0, 0, 0},
1010*da0073e9SAndroid Build Coastguard Worker             {0, 6, 0, 8, 9},
1011*da0073e9SAndroid Build Coastguard Worker             {0, 0, 0, 0, 0},
1012*da0073e9SAndroid Build Coastguard Worker             {0, 16, 0, 18, 19},
1013*da0073e9SAndroid Build Coastguard Worker             {0, 21, 0, 23, 24}},
1014*da0073e9SAndroid Build Coastguard Worker            {{0, 0, 0, 0, 0},
1015*da0073e9SAndroid Build Coastguard Worker             {0, 31, 0, 33, 34},
1016*da0073e9SAndroid Build Coastguard Worker             {0, 0, 0, 0, 0},
1017*da0073e9SAndroid Build Coastguard Worker             {0, 41, 0, 43, 44},
1018*da0073e9SAndroid Build Coastguard Worker             {0, 46, 0, 48, 49}}},
1019*da0073e9SAndroid Build Coastguard Worker           torch::kFloat)));
1020*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 5, 5}));
1021*da0073e9SAndroid Build Coastguard Worker }
1022*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, MaxUnpool3d) {
  // Runs F::max_unpool3d (kernel size 3) on a single element with index 26,
  // first as a 5-D input and then as a 4-D input; the value is expected in
  // the last position of the 3x3x3 output.
  const auto options = F::MaxUnpool3dFuncOptions(3);

  {
    // 5-D case.
    const auto positions = torch::tensor({{{{{26}}}}}, torch::kLong);
    const auto pooled = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    const auto unpooled = F::max_unpool3d(pooled, positions, options);

    const auto reference = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 5);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
  }

  {
    // 4-D case.
    const auto positions = torch::tensor({{{{26}}}}, torch::kLong);
    const auto pooled = torch::tensor(
        {{{{26}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    const auto unpooled = F::max_unpool3d(pooled, positions, options);

    const auto reference = torch::tensor(
        {{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
          {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
          {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}},
        torch::kFloat);

    ASSERT_EQ(unpooled.dim(), 4);
    ASSERT_TRUE(torch::allclose(unpooled, reference));
    ASSERT_EQ(unpooled.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
  }
}
1054*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, ELU) {
  // For each (inplace, alpha) pair, compares F::elu against
  // max(0, x) + min(0, alpha * (exp(x) - 1)) and checks that a BFloat16 run
  // agrees with the default-dtype run within a loose tolerance.
  const auto side = 3;
  const auto numel = side * side * side;

  for (const auto inplace : {false, true}) {
    for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, numel);
      input.resize_({side, side, side});
      auto input_bf16 =
          torch::linspace(-10.0, 10.0, numel).to(torch::kBFloat16);
      input_bf16.resize_({side, side, side});

      // The reference is built before F::elu runs because the in-place
      // variant is expected to overwrite `input` (asserted below).
      const auto zeros = torch::zeros_like(input);
      const auto positive_part = torch::max(zeros, input);
      const auto negative_part =
          torch::min(zeros, alpha * (torch::exp(input) - 1.0));
      const auto reference = positive_part + negative_part;

      const auto options = F::ELUFuncOptions().alpha(alpha).inplace(inplace);
      const auto output = F::elu(input, options);
      const auto output_bf16 = F::elu(input_bf16, options);

      ASSERT_EQ(output.dim(), 3);
      ASSERT_EQ(output.sizes(), std::vector<int64_t>({side, side, side}));
      ASSERT_TRUE(torch::allclose(output, reference));
      ASSERT_TRUE(
          torch::allclose(output_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      if (inplace) {
        // The in-place variant must have written the result into its inputs.
        ASSERT_TRUE(torch::allclose(input, reference));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), output, 1e-2, 1e-2));
      }
    }
  }

  // Calling without options must also work.
  ASSERT_TRUE(F::elu(torch::tensor(1.)).defined());
}
1083*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, SELU) {
  {
    // Compare F::selu against
    // scale * (max(0, x) + min(0, alpha * (exp(x) - 1))) for both the
    // out-of-place and in-place variants, plus a BFloat16 run.
    const double scale = 1.0507009873554804934193349852946;
    const double alpha = 1.6732632423543772848170429916717;
    for (const auto inplace : {false, true}) {
      auto sample = torch::randn({5, 5});
      auto sample_bf16 = sample.clone().to(torch::kBFloat16);

      // Built before F::selu runs because the in-place variant is expected
      // to overwrite `sample` (asserted below).
      const auto positive_part =
          torch::max(torch::zeros_like(sample), sample);
      const auto negative_part = torch::min(
          torch::zeros_like(sample), alpha * (torch::exp(sample) - 1));
      const auto reference = scale * (positive_part + negative_part);

      const auto result = F::selu(sample, inplace);
      const auto result_bf16 = F::selu(sample_bf16, inplace);

      ASSERT_TRUE(torch::allclose(result, reference));
      ASSERT_TRUE(
          torch::allclose(result_bf16.to(torch::kFloat), result, 1e-2, 1e-2));
      if (inplace) {
        ASSERT_TRUE(torch::allclose(sample, reference));
        ASSERT_TRUE(
            torch::allclose(sample_bf16.to(torch::kFloat), result, 1e-2, 1e-2));
      }
    }
  }
  {
    // Omitting the argument must match passing `false` explicitly.
    const auto sample = torch::arange(0, 9, torch::kDouble).view({3, 3});
    const auto implicit_result = F::selu(sample);
    const auto explicit_result = F::selu(sample, false);
    ASSERT_TRUE(torch::allclose(implicit_result, explicit_result));
  }
  ASSERT_TRUE(F::selu(torch::tensor(1.)).defined());
}
1114*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, GLU) {
  // F::glu along dim 1 is compared against
  // first_half * sigmoid(second_half), where the halves split the input
  // along that dim. Calling F::glu without a dim must give the same result.
  const int64_t split_dim = 1;
  const auto sample = torch::randn({4, 2}, torch::requires_grad());
  const auto gated = F::glu(sample, split_dim);

  const auto half = sample.sizes()[split_dim] / 2;
  const auto values = sample.narrow(split_dim, 0, half);
  const auto gates = sample.narrow(split_dim, half, half);
  const auto reference = values * torch::sigmoid(gates);

  ASSERT_TRUE(torch::allclose(gated, reference));
  ASSERT_TRUE(torch::allclose(F::glu(sample), reference));
}
1127*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, GELU) {
  // F::gelu with approximate("none") is compared against
  // x * 0.5 * (1 + erf(x / sqrt(2))) on 100 points in [-3, 3].
  const auto input = torch::linspace(-3.0, 3.0, 100);

  const auto erf_term = torch::erf(input / std::sqrt(2.0));
  const auto reference = input * 0.5 * (1.0 + erf_term);

  const auto options = F::GELUFuncOptions().approximate("none");
  const auto output = F::gelu(input, options);

  ASSERT_TRUE(
      torch::allclose(output, reference, /*rtol=*/1.4e-06, /*atol=*/1e-05));
}
1134*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, TanhGELU) {
  // F::gelu with approximate("tanh") is compared against
  // 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
  // on 100 points in [-3, 3].
  const auto input = torch::linspace(-3.0, 3.0, 100);

  const auto coefficient = std::sqrt(2 / M_PI);
  const auto cubic = input + 0.044715 * input.pow(3.0);
  const auto inner = coefficient * cubic;
  const auto reference = 0.5 * input * (1.0 + torch::tanh(inner));

  const auto options = F::GELUFuncOptions().approximate("tanh");
  const auto output = F::gelu(input, options);

  ASSERT_TRUE(
      torch::allclose(output, reference, /*rtol=*/1.4e-06, /*atol=*/1e-05));
}
1142*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest, Hardshrink) {
  // For each lambda, F::hardshrink is compared against (|x| > lambda) * x on
  // a 3x3x3 grid of values in [-10, 10]; a backward pass through the summed
  // output is also run.
  const auto side = 3;
  for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, side * side * side);
    input.resize_({side, side, side}).set_requires_grad(true);

    const auto options = F::HardshrinkFuncOptions().lambda(lambda);
    const auto output = F::hardshrink(input, options);

    // Reduce to a scalar so backward() can be called without arguments.
    const torch::Tensor total = output.sum();
    total.backward();
    ASSERT_EQ(total.dim(), 0);

    ASSERT_EQ(output.dim(), 3);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({side, side, side}));
    const auto reference = (input.abs() > lambda) * input;
    ASSERT_TRUE(torch::allclose(output, reference));
  }
}
1160*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,OneHot)1161*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, OneHot) {
1162*da0073e9SAndroid Build Coastguard Worker   { // Test #1
1163*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(0, 5, torch::kLong);
1164*da0073e9SAndroid Build Coastguard Worker     auto y = F::one_hot(x % 3);
1165*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
1166*da0073e9SAndroid Build Coastguard Worker         {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);
1167*da0073e9SAndroid Build Coastguard Worker 
1168*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.ndimension(), 2);
1169*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
1170*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 3}));
1171*da0073e9SAndroid Build Coastguard Worker   }
1172*da0073e9SAndroid Build Coastguard Worker 
1173*da0073e9SAndroid Build Coastguard Worker   { // Test #2
1174*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(0, 5, torch::kLong);
1175*da0073e9SAndroid Build Coastguard Worker     auto y = F::one_hot(x % 3, 5);
1176*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
1177*da0073e9SAndroid Build Coastguard Worker         {{1, 0, 0, 0, 0},
1178*da0073e9SAndroid Build Coastguard Worker          {0, 1, 0, 0, 0},
1179*da0073e9SAndroid Build Coastguard Worker          {0, 0, 1, 0, 0},
1180*da0073e9SAndroid Build Coastguard Worker          {1, 0, 0, 0, 0},
1181*da0073e9SAndroid Build Coastguard Worker          {0, 1, 0, 0, 0}},
1182*da0073e9SAndroid Build Coastguard Worker         torch::kLong);
1183*da0073e9SAndroid Build Coastguard Worker 
1184*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.ndimension(), 2);
1185*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
1186*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 5}));
1187*da0073e9SAndroid Build Coastguard Worker   }
1188*da0073e9SAndroid Build Coastguard Worker 
1189*da0073e9SAndroid Build Coastguard Worker   { // Test #3
1190*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(0, 6, torch::kLong);
1191*da0073e9SAndroid Build Coastguard Worker     auto y = F::one_hot(x.view(std::vector<int64_t>({3, 2})) % 3);
1192*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
1193*da0073e9SAndroid Build Coastguard Worker         {{{1, 0, 0}, {0, 1, 0}},
1194*da0073e9SAndroid Build Coastguard Worker          {{0, 0, 1}, {1, 0, 0}},
1195*da0073e9SAndroid Build Coastguard Worker          {{0, 1, 0}, {0, 0, 1}}},
1196*da0073e9SAndroid Build Coastguard Worker         torch::kLong);
1197*da0073e9SAndroid Build Coastguard Worker 
1198*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.ndimension(), 3);
1199*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
1200*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(y.sizes(), std::vector<int64_t>({3, 2, 3}));
1201*da0073e9SAndroid Build Coastguard Worker   }
1202*da0073e9SAndroid Build Coastguard Worker }
1203*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::hardtanh over combinations of min_val, max_val and the inplace
// flag on a 3x3x3 tensor spanning [-10, 10], comparing against a clamp built
// from comparison masks. Also checks that the default-options call returns a
// defined tensor.
TEST_F(FunctionalTest, Hardtanh) {
  const int64_t side = 3;
  const std::vector<int64_t> cube_shape{side, side, side};

  for (const auto lower : {-4.2, -1.0, -0.42, 0.0}) {
    for (const auto upper : {0.0, 0.42, 1.0, 4.2}) {
      for (const auto in_place : {false, true}) {
        auto input = torch::linspace(-10.0, 10.0, side * side * side);
        input.resize_(cube_shape);

        // The reference must be computed before the call below, because the
        // in-place variant is expected to overwrite `input`.
        const auto below = (input < lower) * lower;
        const auto within = ((input >= lower) * (input <= upper)) * input;
        const auto above = (input > upper) * upper;
        auto reference = below + within + above;

        const auto options = F::HardtanhFuncOptions()
                                 .min_val(lower)
                                 .max_val(upper)
                                 .inplace(in_place);
        auto output = F::hardtanh(input, options);

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), cube_shape);
        ASSERT_TRUE(torch::allclose(output, reference));
        if (in_place) {
          // In-place mode must have written the result into the input too.
          ASSERT_TRUE(torch::allclose(input, reference));
        }
      }
    }
  }

  // Calling with default options should still produce a tensor.
  const auto default_result = F::hardtanh(torch::tensor(1.));
  ASSERT_TRUE(default_result.defined());
}
1229*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::leaky_relu over negative slopes, the inplace flag, and the
// kFloat / kBFloat16 dtypes on a 3x3x3 tensor spanning [-10, 10], comparing
// against (x < 0) * x * slope + (x >= 0) * x. Also checks that the
// default-options call returns a defined tensor.
TEST_F(FunctionalTest, LeakyReLU) {
  const int64_t side = 3;
  const std::vector<int64_t> cube_shape{side, side, side};

  for (const auto slope : {0.0, 0.42, 1.0}) {
    for (const auto in_place : {false, true}) {
      for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
        auto input =
            torch::linspace(-10.0, 10.0, side * side * side).to(dtype);
        input.resize_(cube_shape);

        // The reference must be computed before the call below, because the
        // in-place variant is expected to overwrite `input`.
        const auto negative_part = (input < 0) * input * slope;
        const auto non_negative_part = (input >= 0) * input;
        auto reference = negative_part + non_negative_part;

        const auto options =
            F::LeakyReLUFuncOptions().negative_slope(slope).inplace(in_place);
        auto output = F::leaky_relu(input, options);

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), cube_shape);
        ASSERT_TRUE(torch::allclose(output, reference));
        if (in_place) {
          // In-place mode must have written the result into the input too.
          ASSERT_TRUE(torch::allclose(input, reference));
        }
      }
    }
  }

  // Calling with default options should still produce a tensor.
  const auto default_result = F::leaky_relu(torch::tensor(1.));
  ASSERT_TRUE(default_result.defined());
}
1255*da0073e9SAndroid Build Coastguard Worker 
// Compares F::logsigmoid on a 3x3x3 tensor spanning [-10, 10] against
// log(1 / (1 + exp(-x))).
//
// The functional form is exercised directly; the previously declared (and
// never used) LogSigmoid module instance has been dropped.
TEST_F(FunctionalTest, LogSigmoid) {
  const auto size = 3;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size});
  auto y = F::logsigmoid(x);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = torch::log(
      torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
  // rtol = 1e-4, atol = 1e-7.
  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
}
1269*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,GumbelSoftmax)1270*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, GumbelSoftmax) {
1271*da0073e9SAndroid Build Coastguard Worker   // Test 1: No-options
1272*da0073e9SAndroid Build Coastguard Worker   {
1273*da0073e9SAndroid Build Coastguard Worker     auto logits = torch::randn({5});
1274*da0073e9SAndroid Build Coastguard Worker     int expected_count = 1;
1275*da0073e9SAndroid Build Coastguard Worker     auto y_draw = F::gumbel_softmax(logits);
1276*da0073e9SAndroid Build Coastguard Worker 
1277*da0073e9SAndroid Build Coastguard Worker     // All values positive
1278*da0073e9SAndroid Build Coastguard Worker     ASSERT_GE(y_draw.min().item<int>(), 0);
1279*da0073e9SAndroid Build Coastguard Worker     // Shape unchanged
1280*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1281*da0073e9SAndroid Build Coastguard Worker     // One choice per draw
1282*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
1283*da0073e9SAndroid Build Coastguard Worker         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1284*da0073e9SAndroid Build Coastguard Worker   }
1285*da0073e9SAndroid Build Coastguard Worker 
1286*da0073e9SAndroid Build Coastguard Worker   // Test 2: 1D shape, 0 and -1 dim
1287*da0073e9SAndroid Build Coastguard Worker   for (const auto dim : {0, -1}) {
1288*da0073e9SAndroid Build Coastguard Worker     auto logits = torch::randn({5});
1289*da0073e9SAndroid Build Coastguard Worker     int expected_count = 1;
1290*da0073e9SAndroid Build Coastguard Worker     auto y_draw = F::gumbel_softmax(
1291*da0073e9SAndroid Build Coastguard Worker         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim));
1292*da0073e9SAndroid Build Coastguard Worker 
1293*da0073e9SAndroid Build Coastguard Worker     // All values positive
1294*da0073e9SAndroid Build Coastguard Worker     ASSERT_GE(y_draw.min().item<int>(), 0);
1295*da0073e9SAndroid Build Coastguard Worker     // Shape unchanged
1296*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1297*da0073e9SAndroid Build Coastguard Worker     // One choice per draw
1298*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
1299*da0073e9SAndroid Build Coastguard Worker         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1300*da0073e9SAndroid Build Coastguard Worker   }
1301*da0073e9SAndroid Build Coastguard Worker 
1302*da0073e9SAndroid Build Coastguard Worker   { // Test 3: 2D shape, 1 dim
1303*da0073e9SAndroid Build Coastguard Worker     auto logits = torch::randn({5, 4});
1304*da0073e9SAndroid Build Coastguard Worker     int expected_count = 5;
1305*da0073e9SAndroid Build Coastguard Worker     auto y_draw = F::gumbel_softmax(
1306*da0073e9SAndroid Build Coastguard Worker         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1));
1307*da0073e9SAndroid Build Coastguard Worker 
1308*da0073e9SAndroid Build Coastguard Worker     // All values positive
1309*da0073e9SAndroid Build Coastguard Worker     ASSERT_GE(y_draw.min().item<int>(), 0);
1310*da0073e9SAndroid Build Coastguard Worker     // Shape unchanged
1311*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1312*da0073e9SAndroid Build Coastguard Worker     // One choice per draw
1313*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
1314*da0073e9SAndroid Build Coastguard Worker         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1315*da0073e9SAndroid Build Coastguard Worker   }
1316*da0073e9SAndroid Build Coastguard Worker 
1317*da0073e9SAndroid Build Coastguard Worker   // Test 4: 3D shape, 1 and -1 dim
1318*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1319*da0073e9SAndroid Build Coastguard Worker   int dims[] = {1, -1};
1320*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers)
1321*da0073e9SAndroid Build Coastguard Worker   int expected[] = {5 * 3, 5 * 4};
1322*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(2)) {
1323*da0073e9SAndroid Build Coastguard Worker     auto logits = torch::randn({5, 4, 3});
1324*da0073e9SAndroid Build Coastguard Worker     int expected_count = expected[i];
1325*da0073e9SAndroid Build Coastguard Worker     auto y_draw = F::gumbel_softmax(
1326*da0073e9SAndroid Build Coastguard Worker         logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i]));
1327*da0073e9SAndroid Build Coastguard Worker 
1328*da0073e9SAndroid Build Coastguard Worker     // All values positive
1329*da0073e9SAndroid Build Coastguard Worker     ASSERT_GE(y_draw.min().item<int>(), 0);
1330*da0073e9SAndroid Build Coastguard Worker     // Shape unchanged
1331*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1332*da0073e9SAndroid Build Coastguard Worker     // One choice per draw
1333*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
1334*da0073e9SAndroid Build Coastguard Worker         y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1335*da0073e9SAndroid Build Coastguard Worker   }
1336*da0073e9SAndroid Build Coastguard Worker 
1337*da0073e9SAndroid Build Coastguard Worker   { // Test 5: Straight through
1338*da0073e9SAndroid Build Coastguard Worker     int num_draws = 100;
1339*da0073e9SAndroid Build Coastguard Worker     auto logits = torch::tensor({{0.2, 0.8, 0.1}});
1340*da0073e9SAndroid Build Coastguard Worker     logits = logits.reshape({1, 3});
1341*da0073e9SAndroid Build Coastguard Worker     logits.requires_grad();
1342*da0073e9SAndroid Build Coastguard Worker     auto probs = logits.softmax(-1);
1343*da0073e9SAndroid Build Coastguard Worker 
1344*da0073e9SAndroid Build Coastguard Worker     auto counts = torch::zeros_like(logits);
1345*da0073e9SAndroid Build Coastguard Worker     torch::Tensor y_draw;
1346*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(num_draws)) {
1347*da0073e9SAndroid Build Coastguard Worker       (void)i; // Suppress unused variable warning
1348*da0073e9SAndroid Build Coastguard Worker       y_draw =
1349*da0073e9SAndroid Build Coastguard Worker           F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true));
1350*da0073e9SAndroid Build Coastguard Worker       counts += y_draw;
1351*da0073e9SAndroid Build Coastguard Worker     }
1352*da0073e9SAndroid Build Coastguard Worker 
1353*da0073e9SAndroid Build Coastguard Worker     // All values positive
1354*da0073e9SAndroid Build Coastguard Worker     ASSERT_GE(y_draw.min().item<int>(), 0);
1355*da0073e9SAndroid Build Coastguard Worker     // Each experiment should result in 1 draw
1356*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(counts.sum().item<int>(), num_draws);
1357*da0073e9SAndroid Build Coastguard Worker 
1358*da0073e9SAndroid Build Coastguard Worker     // Check results are asymptotically as expected
1359*da0073e9SAndroid Build Coastguard Worker     auto expected = probs * num_draws;
1360*da0073e9SAndroid Build Coastguard Worker     // ~z is approximately N(0,1) for unbiased count
1361*da0073e9SAndroid Build Coastguard Worker     auto z = (counts - expected) / (expected * (1 - probs)).sqrt();
1362*da0073e9SAndroid Build Coastguard Worker     // A (lazy) approximate 99% two-sided test:
1363*da0073e9SAndroid Build Coastguard Worker     // occurs with prob alpha~>=0.01 if unbiased
1364*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(z.abs().max().item<float>(), 2.58);
1365*da0073e9SAndroid Build Coastguard Worker   }
1366*da0073e9SAndroid Build Coastguard Worker }
1367*da0073e9SAndroid Build Coastguard Worker 
// Checks F::softmax along dim 1 of a 2x5 tensor against exp(x) / sum(exp(x)),
// one row at a time.
TEST_F(FunctionalTest, Softmax) {
  const int64_t rows = 2;
  const int64_t cols = 5;
  auto input = torch::arange(rows * cols, torch::kFloat).reshape({rows, cols});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto actual = F::softmax(input, /*dim=*/1);

  // Per-row normalisation constants.
  auto row_sums = torch::sum(torch::exp(input), 1);

  for (const auto row : c10::irange(rows)) {
    auto reference = torch::exp(input[row]) / row_sums[row];
    ASSERT_TRUE(torch::allclose(actual[row], reference));
  }
}
1379*da0073e9SAndroid Build Coastguard Worker 
// Checks F::softmin along dim 1 of a 2x5 tensor against
// exp(-x) / sum(exp(-x)), one row at a time.
TEST_F(FunctionalTest, Softmin) {
  const int64_t rows = 2;
  const int64_t cols = 5;
  auto input = torch::arange(rows * cols, torch::kFloat).reshape({rows, cols});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto actual = F::softmin(input, /*dim=*/1);

  // Per-row normalisation constants.
  auto row_sums = torch::sum(torch::exp(-input), 1);

  for (const auto row : c10::irange(rows)) {
    auto reference = torch::exp(-input[row]) / row_sums[row];
    ASSERT_TRUE(torch::allclose(actual[row], reference));
  }
}
1391*da0073e9SAndroid Build Coastguard Worker 
// Checks F::log_softmax along dim 1 of a 2x5 tensor against
// log(exp(x) / sum(exp(x))), one row at a time.
TEST_F(FunctionalTest, LogSoftmax) {
  const int64_t rows = 2;
  const int64_t cols = 5;
  auto input = torch::arange(rows * cols, torch::kFloat).reshape({rows, cols});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  auto actual = F::log_softmax(input, /*dim=*/1);

  // Per-row normalisation constants.
  auto row_sums = torch::sum(torch::exp(input), 1);

  for (const auto row : c10::irange(rows)) {
    auto softmax_row = torch::exp(input[row]) / row_sums[row];
    auto reference = torch::log(softmax_row);
    ASSERT_TRUE(torch::allclose(actual[row], reference));
  }
}
1403*da0073e9SAndroid Build Coastguard Worker 
// Checks F::prelu on a random 42x24 input with a 24-element weight against
// (x < 0) * w * x + (x >= 0) * x. rand() values are rescaled with
// `* 200 - 100`.
TEST_F(FunctionalTest, PReLU) {
  // Draw order (input first, then weight) is kept as-is because the fixture
  // seeds the generator.
  const auto input = torch::rand({42, 24}) * 200 - 100;
  const auto weight = torch::rand(24) * 200 - 100;

  const auto output = F::prelu(input, weight);
  ASSERT_EQ(output.sizes(), std::vector<int64_t>({42, 24}));

  const auto negative_part = (input < 0) * weight * input;
  const auto non_negative_part = (input >= 0) * input;
  const auto reference = negative_part + non_negative_part;
  ASSERT_TRUE(torch::allclose(output, reference));
}
1412*da0073e9SAndroid Build Coastguard Worker 
// Checks that F::layer_norm with normalized_shape {2, 2} and eps 2e-5 matches
// torch::layer_norm called with the same shape and eps and with undefined
// weight and bias tensors.
TEST_F(FunctionalTest, LayerNorm) {
  const std::vector<int64_t> normalized_shape{2, 2};
  const double eps = 2e-5;

  const auto input = torch::randn({2, 2});

  const auto options = F::LayerNormFuncOptions(normalized_shape).eps(eps);
  auto actual = F::layer_norm(input, options);

  auto reference = torch::layer_norm(
      input, normalized_shape, torch::Tensor(), torch::Tensor(), eps);
  ASSERT_TRUE(torch::allclose(actual, reference));
}
1420*da0073e9SAndroid Build Coastguard Worker 
// Checks that F::group_norm with 2 groups and eps 2e-5 matches
// torch::group_norm called with the same group count and eps and with
// undefined weight and bias tensors.
TEST_F(FunctionalTest, GroupNorm) {
  const int64_t num_groups = 2;
  const double eps = 2e-5;

  const auto input = torch::randn({2, 2});

  const auto options = F::GroupNormFuncOptions(num_groups).eps(eps);
  auto actual = F::group_norm(input, options);

  auto reference = torch::group_norm(
      input, num_groups, torch::Tensor(), torch::Tensor(), eps);
  ASSERT_TRUE(torch::allclose(actual, reference));
}
1428*da0073e9SAndroid Build Coastguard Worker 
// Checks F::local_response_norm with size 2 on the values 100..117 laid out
// as a 3x3x2 tensor against precomputed expected values.
TEST_F(FunctionalTest, LocalResponseNorm) {
  const auto input = torch::arange(100, 118).resize_({3, 3, 2});
  const auto options = F::LocalResponseNormFuncOptions(2);
  const auto output = F::local_response_norm(input, options);

  // Precomputed reference output.
  const auto reference = torch::tensor(
      {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}},
       {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}},
       {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}},
      torch::kFloat);

  ASSERT_EQ(output.ndimension(), 3);
  ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 2}));
  // rtol = 1e-4, atol = 1e-7.
  ASSERT_TRUE(torch::allclose(output, reference, 1e-4, 1e-7));
}
1441*da0073e9SAndroid Build Coastguard Worker 
// Checks F::linear on a 3x3x2 input with a 3x2 weight, first with a
// 3-element bias and then without one, against precomputed expected values.
TEST_F(FunctionalTest, Linear) {
  { // With bias.
    const auto input = torch::arange(100., 118).resize_({3, 3, 2});
    const auto weight = torch::arange(200., 206).resize_({3, 2});
    const auto bias = torch::arange(300., 303);

    const auto output = F::linear(input, weight, bias);
    const auto reference = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 3}));
    ASSERT_TRUE(torch::allclose(output, reference));
  }
  { // Without bias.
    const auto input = torch::arange(100., 118).resize_({3, 3, 2});
    const auto weight = torch::arange(200., 206).resize_({3, 2});

    const auto output = F::linear(input, weight);
    const auto reference = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), torch::IntArrayRef({3, 3, 3}));
    ASSERT_TRUE(torch::allclose(output, reference));
  }
}
1471*da0073e9SAndroid Build Coastguard Worker 
// Checks that F::embedding with default options matches torch::embedding
// called on the same 10x3 normally-initialised weight and 2x4 index tensor.
TEST_F(FunctionalTest, Embedding) {
  const auto indices =
      torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);

  auto table = torch::empty({10, 3});
  torch::nn::init::normal_(table);

  auto actual = F::embedding(indices, table);

  // Note the swapped argument order: torch::embedding takes the weight first.
  const auto contiguous_indices = indices.contiguous();
  auto reference =
      torch::embedding(table, contiguous_indices, -1, false, false);

  ASSERT_TRUE(torch::allclose(actual, reference));
}
1480*da0073e9SAndroid Build Coastguard Worker 
// Checks F::embedding_bag against torch::embedding_bag in two configurations:
//   1. 1-D indices with explicit offsets, mode kSum and padding_idx 4;
//   2. 2-D indices with default options, where the reference call is given the
//      flattened indices and one offset per input row.
TEST_F(FunctionalTest, EmbeddingBag) {
  const auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
  auto offsets = torch::tensor({0, 4}, torch::kLong);
  auto weight = torch::empty({10, 3});
  torch::nn::init::normal_(weight);
  auto y = F::embedding_bag(
      input,
      weight,
      F::EmbeddingBagFuncOptions()
          .mode(torch::kSum)
          .offsets(offsets)
          .padding_idx(4));
  // NOTE(review): the integer mode argument (0 here) presumably selects the
  // same reduction as torch::kSum above; the trailing 4 mirrors padding_idx.
  auto y_exp = std::get<0>(torch::embedding_bag(
      weight, input, offsets, false, 0, false, torch::Tensor(), false, 4));
  ASSERT_TRUE(torch::allclose(y, y_exp));

  // no options test
  const auto input_ = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  // One offset per row of the 2-D input: 0, size(1), 2 * size(1), ...
  // The offsets are created on the device of `input_`, the tensor they
  // describe (previously the unrelated 1-D `input` from the first case was
  // used).
  auto offsets_ = torch::arange(
      0,
      input_.numel(),
      input_.size(1),
      torch::TensorOptions().dtype(torch::kLong).device(input_.device()));
  y = F::embedding_bag(input_, weight);
  // NOTE(review): mode argument 1 presumably matches the default reduction of
  // F::embedding_bag; confirm against the options' defaults.
  y_exp = std::get<0>(torch::embedding_bag(
      weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor()));
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1509*da0073e9SAndroid Build Coastguard Worker 
// Checks F::bilinear with and without a bias against precomputed results, and
// then verifies that mixing input dtypes raises the expected error message.
TEST_F(FunctionalTest, Bilinear) {
  auto lhs = torch::tensor({{1, 2, 3}, {7, 6, 5}});
  auto rhs = torch::tensor({{7, 4}, {8, 9}});
  auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}});
  auto bias = torch::tensor({1});

  // With bias.
  auto biased = F::bilinear(lhs, rhs, weight, bias);
  auto biased_reference = torch::tensor({{449}, {1702}}).reshape({2, 1});
  ASSERT_EQ(biased.ndimension(), 2);
  ASSERT_EQ(biased.sizes(), torch::IntArrayRef({2, 1}));
  ASSERT_TRUE(torch::allclose(biased, biased_reference, 1e-4, 1e-7));

  // Without bias: each expected entry is smaller by the bias value of 1.
  auto unbiased = F::bilinear(lhs, rhs, weight);
  auto unbiased_reference = torch::tensor({{448, 1701}}).reshape({2, 1});
  ASSERT_EQ(unbiased.ndimension(), 2);
  ASSERT_EQ(unbiased.sizes(), torch::IntArrayRef({2, 1}));
  ASSERT_TRUE(torch::allclose(unbiased, unbiased_reference, 1e-4, 1e-7));

  // Mismatched dtypes must be rejected with a descriptive message.
  lhs = lhs.to(torch::kFloat64);
  rhs = rhs.to(torch::kInt32);
  weight = weight.to(torch::kInt32);
  ASSERT_THROWS_WITH(
      F::bilinear(lhs, rhs, weight),
      "All tensors must have the same dtype, got input1: double, input2: int, weight: int");
}
1535*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,Normalize)1536*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, Normalize) {
1537*da0073e9SAndroid Build Coastguard Worker   const auto expected = torch::tensor(
1538*da0073e9SAndroid Build Coastguard Worker       {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000},
1539*da0073e9SAndroid Build Coastguard Worker         {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}},
1540*da0073e9SAndroid Build Coastguard Worker       torch::requires_grad().dtype(torch::kFloat));
1541*da0073e9SAndroid Build Coastguard Worker   { // Test #1
1542*da0073e9SAndroid Build Coastguard Worker     auto input = torch::tensor(
1543*da0073e9SAndroid Build Coastguard Worker         {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}},
1544*da0073e9SAndroid Build Coastguard Worker         torch::dtype(torch::kFloat).requires_grad(true));
1545*da0073e9SAndroid Build Coastguard Worker     auto norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
1546*da0073e9SAndroid Build Coastguard Worker 
1547*da0073e9SAndroid Build Coastguard Worker     // reduce to scalar to call .backward()
1548*da0073e9SAndroid Build Coastguard Worker     torch::Tensor s = norm.sum();
1549*da0073e9SAndroid Build Coastguard Worker     s.backward();
1550*da0073e9SAndroid Build Coastguard Worker 
1551*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(s.ndimension(), 0);
1552*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(input.grad().numel(), 10);
1553*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(norm, expected));
1554*da0073e9SAndroid Build Coastguard Worker   }
1555*da0073e9SAndroid Build Coastguard Worker 
1556*da0073e9SAndroid Build Coastguard Worker   { // Test #2 Check variations of optional arguments
1557*da0073e9SAndroid Build Coastguard Worker     auto input = torch::tensor(
1558*da0073e9SAndroid Build Coastguard Worker         {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat));
1559*da0073e9SAndroid Build Coastguard Worker     auto output = torch::randn({1, 2, 5}, torch::dtype(torch::kFloat));
1560*da0073e9SAndroid Build Coastguard Worker     // non-null output argument
1561*da0073e9SAndroid Build Coastguard Worker     F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1).out(output));
1562*da0073e9SAndroid Build Coastguard Worker     // default options
1563*da0073e9SAndroid Build Coastguard Worker     F::normalize(input);
1564*da0073e9SAndroid Build Coastguard Worker 
1565*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(output, expected));
1566*da0073e9SAndroid Build Coastguard Worker   }
1567*da0073e9SAndroid Build Coastguard Worker 
1568*da0073e9SAndroid Build Coastguard Worker   { // Test #3 Base case of scalar tensor
1569*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({}, torch::requires_grad());
1570*da0073e9SAndroid Build Coastguard Worker     torch::Tensor norm =
1571*da0073e9SAndroid Build Coastguard Worker         F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
1572*da0073e9SAndroid Build Coastguard Worker     norm.backward();
1573*da0073e9SAndroid Build Coastguard Worker 
1574*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(input.grad().numel(), 1);
1575*da0073e9SAndroid Build Coastguard Worker   }
1576*da0073e9SAndroid Build Coastguard Worker }
1577*da0073e9SAndroid Build Coastguard Worker 
// Checks F::relu against a mask-built reference for both in-place settings,
// once through the options builder and once passing the flag directly.
TEST_F(FunctionalTest, ReLU) {
  const auto side = 3;
  for (const bool in_place : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, side * side * side);
    input.resize_({side, side, side});
    // Reference is computed before the call so an in-place run cannot
    // affect it.
    auto reference = (input < 0) * 0 + (input >= 0) * input;
    auto result = F::relu(input, F::ReLUFuncOptions().inplace(in_place));

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    ASSERT_TRUE(torch::allclose(result, reference));
    if (in_place) {
      ASSERT_TRUE(torch::allclose(input, reference));
    }

    // Same checks with the flag passed directly.
    // NOLINTNEXTLINE(bugprone-argument-comment)
    result = F::relu(input, /*inplace=*/in_place);

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    ASSERT_TRUE(torch::allclose(result, reference));
    if (in_place) {
      ASSERT_TRUE(torch::allclose(input, reference));
    }
  }
  ASSERT_TRUE(F::relu(torch::tensor(1.)).defined());
}
1605*da0073e9SAndroid Build Coastguard Worker 
// Checks F::relu called with no options against a mask-built reference.
TEST_F(FunctionalTest, ReLUDefaultOptions) {
  const auto side = 3;
  auto input = torch::linspace(-10.0, 10.0, side * side * side);
  input.resize_({side, side, side});
  auto reference = (input < 0) * 0 + (input >= 0) * input;
  auto result = F::relu(input);

  ASSERT_EQ(result.ndimension(), 3);
  ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
  ASSERT_TRUE(torch::allclose(result, reference));
}
1617*da0073e9SAndroid Build Coastguard Worker 
// Checks F::relu6 against a mask-built reference (0 below zero, identity on
// [0, 6], 6 above) for both in-place settings and both call forms.
TEST_F(FunctionalTest, ReLU6) {
  const auto side = 3;
  for (const bool in_place : {false, true}) {
    auto input = torch::linspace(-10.0, 10.0, side * side * side);
    input.resize_({side, side, side});
    // Reference is computed before the call so an in-place run cannot
    // affect it.
    auto reference = (input < 0) * 0 +
        ((input >= 0) * (input <= 6)) * input + (input > 6) * 6;
    auto result = F::relu6(input, F::ReLU6FuncOptions().inplace(in_place));

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    ASSERT_TRUE(torch::allclose(result, reference));
    if (in_place) {
      ASSERT_TRUE(torch::allclose(input, reference));
    }

    // Same checks with the flag passed directly.
    // NOLINTNEXTLINE(bugprone-argument-comment)
    result = F::relu6(input, /*inplace=*/in_place);

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    ASSERT_TRUE(torch::allclose(result, reference));
    if (in_place) {
      ASSERT_TRUE(torch::allclose(input, reference));
    }
  }
  ASSERT_TRUE(F::relu6(torch::tensor(1.)).defined());
}
1645*da0073e9SAndroid Build Coastguard Worker 
// Checks F::relu6 called with no options against a mask-built reference.
TEST_F(FunctionalTest, ReLU6DefaultOptions) {
  const auto side = 3;
  auto input = torch::linspace(-10.0, 10.0, side * side * side);
  input.resize_({side, side, side});
  auto reference = (input < 0) * 0 +
      ((input >= 0) * (input <= 6)) * input + (input > 6) * 6;
  auto result = F::relu6(input);

  ASSERT_EQ(result.ndimension(), 3);
  ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
  ASSERT_TRUE(torch::allclose(result, reference));
}
1657*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::rrelu over lower/upper bounds, in-place settings and two dtypes.
// The output is not compared to fixed values; instead every element must
// satisfy a containment check: unchanged for non-negative inputs, and between
// upper * x and lower * x for negative inputs.
TEST_F(FunctionalTest, RReLU) {
  const auto side = 3;
  for (const auto lower : {0.01, 0.1, 0.2}) {
    for (const auto upper : {0.3, 0.4, 0.5}) {
      for (const bool in_place : {false, true}) {
        for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
          auto input =
              torch::linspace(-10.0, 10.0, side * side * side).to(dtype);
          input.resize_({side, side, side});
          // Keep a pristine copy: `input` may be overwritten in place.
          auto pristine = input.clone();
          const auto options =
              F::RReLUFuncOptions().lower(lower).upper(upper).inplace(
                  in_place);
          auto result = F::rrelu(input, options);
          // 1 where the element satisfies the containment check, else 0.
          auto within_bounds =
              ((pristine >= 0) * (pristine == result) +
               (pristine < 0) * (result >= pristine * upper) *
                   (result <= lower * pristine)) *
              1.0;

          ASSERT_EQ(result.ndimension(), 3);
          ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
          ASSERT_TRUE(torch::allclose(
              within_bounds, torch::ones_like(within_bounds)));
          if (in_place) {
            ASSERT_TRUE(torch::allclose(input, result));
          }
        }
      }
    }
  }
  ASSERT_TRUE(F::rrelu(torch::tensor(1.)).defined());
}
1687*da0073e9SAndroid Build Coastguard Worker 
// Runs F::rrelu with no options for two dtypes and applies the same
// containment check as the RReLU test, using bounds 1/8 and 1/3 (the values
// this test takes to be the defaults).
TEST_F(FunctionalTest, RReLUDefaultOptions) {
  const auto side = 3;
  const auto lower = 1.0 / 8.0;
  const auto upper = 1.0 / 3.0;
  for (const auto dtype : {torch::kFloat, torch::kBFloat16}) {
    auto input = torch::linspace(-10.0, 10.0, side * side * side).to(dtype);
    input.resize_({side, side, side});
    auto pristine = input.clone();
    auto result = F::rrelu(input);
    // 1 where the element satisfies the containment check, else 0.
    auto within_bounds =
        ((pristine >= 0) * (pristine == result) +
         (pristine < 0) * (result >= pristine * upper) *
             (result <= lower * pristine)) *
        1.0;

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    ASSERT_TRUE(
        torch::allclose(within_bounds, torch::ones_like(within_bounds)));
  }
}
1706*da0073e9SAndroid Build Coastguard Worker 
// Checks F::celu for several alpha values and both in-place settings against
// max(0, x) + min(0, alpha * (exp(x / alpha) - 1)); a BFloat16 run is
// compared to the float result with a looser tolerance.
TEST_F(FunctionalTest, CELU) {
  const auto side = 3;
  for (const bool in_place : {false, true}) {
    for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, side * side * side);
      input.resize_({side, side, side});
      auto input_bf16 = input.clone().to(torch::kBFloat16);
      // Reference is built before calling celu because `input` may be
      // overwritten in place.
      auto reference = torch::max(torch::zeros_like(input), input) +
          torch::min(torch::zeros_like(input),
                     alpha * (torch::exp(input / alpha) - 1.0));
      auto result =
          F::celu(input, F::CELUFuncOptions().alpha(alpha).inplace(in_place));
      auto result_bf16 = F::celu(
          input_bf16, F::CELUFuncOptions().alpha(alpha).inplace(in_place));

      ASSERT_EQ(result.ndimension(), 3);
      ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
      ASSERT_TRUE(torch::allclose(result, reference));
      ASSERT_TRUE(
          torch::allclose(result_bf16.to(torch::kFloat), result, 1e-2, 1e-2));
      if (in_place) {
        ASSERT_TRUE(torch::allclose(input, reference));
        ASSERT_TRUE(
            torch::allclose(input_bf16.to(torch::kFloat), result, 1e-2, 1e-2));
      }
    }
  }
  ASSERT_TRUE(F::celu(torch::tensor(1.)).defined());
}
1733*da0073e9SAndroid Build Coastguard Worker 
// Checks F::celu called with no options against the closed-form reference
// evaluated with alpha = 1.0 (the value this test takes to be the default),
// plus a BFloat16 run compared with a looser tolerance.
TEST_F(FunctionalTest, CELUDefaultOptions) {
  const auto side = 3;
  const auto alpha = 1.0;
  auto input = torch::linspace(-10.0, 10.0, side * side * side);
  input.resize_({side, side, side});
  auto input_bf16 = input.clone().to(torch::kBFloat16);
  auto reference = torch::max(torch::zeros_like(input), input) +
      torch::min(torch::zeros_like(input),
                 alpha * (torch::exp(input / alpha) - 1.0));
  auto result = F::celu(input);
  auto result_bf16 = F::celu(input_bf16);

  ASSERT_EQ(result.ndimension(), 3);
  ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
  ASSERT_TRUE(torch::allclose(result, reference));
  ASSERT_TRUE(
      torch::allclose(result_bf16.to(torch::kFloat), result, 1e-2, 1e-2));
}
1750*da0073e9SAndroid Build Coastguard Worker 
// Checks F::pixel_shuffle with factor 2 on a {1, 4, 2, 2} input: the result
// must have shape {1, 1, 4, 4} and match the hand-written expectation.
TEST_F(FunctionalTest, PixelShuffle) {
  auto packed = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto reference = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto shuffled = F::pixel_shuffle(packed, 2);

  ASSERT_EQ(shuffled.ndimension(), 4);
  ASSERT_EQ(shuffled.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
  ASSERT_TRUE(shuffled.allclose(reference));
}
1767*da0073e9SAndroid Build Coastguard Worker 
// Checks F::pixel_unshuffle with factor 2 on a {1, 1, 4, 4} input: the result
// must have shape {1, 4, 2, 2} and match the hand-written expectation.
TEST_F(FunctionalTest, PixelUnshuffle) {
  auto spread = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto reference = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto unshuffled = F::pixel_unshuffle(spread, 2);

  ASSERT_EQ(unshuffled.ndimension(), 4);
  ASSERT_EQ(unshuffled.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
  ASSERT_TRUE(unshuffled.allclose(reference));
}
1784*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::softplus over beta and threshold values and compares against a
// reference built from log(1 + exp(x * beta)) / beta below the threshold and
// the identity above it.
TEST_F(FunctionalTest, Softplus) {
  const auto side = 3;
  for (const auto beta : {0.5, 1.0, 2.0}) {
    for (const auto threshold : {1.0, 3.0, 5.0}) {
      // NOTE(review): linspace yields 61 points while resize_ asks for
      // side^3 = 27 elements; confirm which values survive and whether the
      // (x > threshold) branch is actually exercised.
      auto input = torch::linspace(-3.0, 3.0, 61);
      input.resize_({side, side, side});
      auto reference =
          (input <= threshold) * torch::log(1 + torch::exp(input * beta)) /
              beta +
          (input > threshold) * input;
      const auto options =
          F::SoftplusFuncOptions().beta(beta).threshold(threshold);
      auto result = F::softplus(input, options);

      ASSERT_EQ(result.ndimension(), 3);
      ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
      ASSERT_TRUE(torch::allclose(result, reference));
    }
  }
}
1803*da0073e9SAndroid Build Coastguard Worker 
// Checks F::softplus called with no options against the reference formula
// evaluated with beta = 1.0 and threshold = 20.0 (the values this test takes
// to be the defaults).
TEST_F(FunctionalTest, SoftplusDefaultOptions) {
  const auto side = 3;
  const auto beta = 1.0;
  const auto threshold = 20.0;
  // NOTE(review): linspace yields 61 points while resize_ asks for
  // side^3 = 27 elements; confirm which values survive.
  auto input = torch::linspace(-3.0, 3.0, 61);
  input.resize_({side, side, side});
  auto reference =
      (input <= threshold) * torch::log(1 + torch::exp(input * beta)) / beta +
      (input > threshold) * input;
  auto result = F::softplus(input);

  ASSERT_EQ(result.ndimension(), 3);
  ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
  ASSERT_TRUE(torch::allclose(result, reference));
}
1818*da0073e9SAndroid Build Coastguard Worker 
// Checks F::fold on an all-ones {1, 12, 2} double tensor with options
// ({3, 2}, {2, 2}); the result must be {1, 3, 3, 2} with the listed values.
TEST_F(FunctionalTest, Fold) {
  auto columns = torch::ones({1, 3 * 2 * 2, 2}, torch::kDouble);
  auto folded = F::fold(columns, F::FoldFuncOptions({3, 2}, {2, 2}));
  auto reference = torch::tensor(
      {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
        {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
      torch::kDouble);

  ASSERT_EQ(folded.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
  ASSERT_TRUE(folded.allclose(reference));
}
1831*da0073e9SAndroid Build Coastguard Worker 
// Checks F::unfold on arange(12) viewed as {1, 2, 2, 3} with options
// ({2, 2}), padding 1 and stride 2; the result must be {1, 8, 4} with the
// listed values.
TEST_F(FunctionalTest, Unfold) {
  auto image = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3});
  const auto options = F::UnfoldFuncOptions({2, 2}).padding(1).stride(2);
  auto patches = F::unfold(image, options);
  auto reference = torch::tensor(
      {{{0.0, 0.0, 0.0, 4.0},
        {0.0, 0.0, 3.0, 5.0},
        {0.0, 1.0, 0.0, 0.0},
        {0.0, 2.0, 0.0, 0.0},
        {0.0, 0.0, 0.0, 10.0},
        {0.0, 0.0, 9.0, 11.0},
        {0.0, 7.0, 0.0, 0.0},
        {6.0, 8.0, 0.0, 0.0}}},
      torch::kDouble);

  ASSERT_EQ(patches.sizes(), std::vector<int64_t>({1, 8, 4}));
  ASSERT_TRUE(patches.allclose(reference));
}
1850*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::softshrink over several lambda values, runs backward through the
// summed output, and compares the forward result with
// (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda).
TEST_F(FunctionalTest, Softshrink) {
  const auto side = 3;
  for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, side * side * side);
    input.resize_({side, side, side}).set_requires_grad(true);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    auto result = F::softshrink(input, /*lambda=*/lambda);
    // backward() needs a scalar, so collapse the result first
    torch::Tensor total = result.sum();

    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(result.ndimension(), 3);
    ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
    auto reference =
        (input < -lambda) * (input + lambda) +
        (input > lambda) * (input - lambda);
    ASSERT_TRUE(torch::allclose(result, reference));
  }
}
1869*da0073e9SAndroid Build Coastguard Worker 
// Checks F::softshrink called with no options: runs backward through the
// summed output, verifies the output shape, and compares the forward values
// with the reference formula evaluated at lambda = 0.5 (the value this test
// takes to be the default).
TEST_F(FunctionalTest, SoftshrinkDefaultOptions) {
  const auto size = 3;
  const auto lambda = 0.5;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size}).set_requires_grad(true);
  auto y = F::softshrink(x);
  // backward() needs a scalar, so collapse the result first
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
  auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
  // Previously y_exp was computed but never compared, so the values produced
  // with default options went unchecked.
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1885*da0073e9SAndroid Build Coastguard Worker 
// Compares F::softsign with the closed form x / (1 + |x|) on random input.
TEST_F(FunctionalTest, Softsign) {
  auto input = torch::randn(100) * 10;
  auto reference = input / (1 + input.abs());
  auto result = F::softsign(input);

  ASSERT_TRUE(torch::allclose(result, reference));
}
1893*da0073e9SAndroid Build Coastguard Worker 
// Compares F::mish with the closed form x * tanh(log1p(exp(x))) on random
// input.
TEST_F(FunctionalTest, Mish) {
  auto input = torch::randn(100) * 10;
  auto reference = input * input.exp().log1p().tanh();
  auto result = F::mish(input);

  ASSERT_TRUE(torch::allclose(result, reference));
}
1901*da0073e9SAndroid Build Coastguard Worker 
// Compares F::tanhshrink with the closed form x - tanh(x) on random input.
TEST_F(FunctionalTest, Tanhshrink) {
  auto input = torch::randn(100) * 10;
  auto reference = input - input.tanh();
  auto result = F::tanhshrink(input);

  ASSERT_TRUE(torch::allclose(result, reference));
}
1909*da0073e9SAndroid Build Coastguard Worker 
// Sweeps F::threshold over threshold/value pairs and both in-place settings,
// comparing with (x <= threshold) * value + (x > threshold) * x.
TEST_F(FunctionalTest, Threshold) {
  const auto side = 3;
  for (const auto threshold : {0.5, 1.0, 2.0}) {
    for (const auto value : {0.5, 1.0, 2.0}) {
      for (const bool in_place : {false, true}) {
        // NOTE(review): linspace yields 61 points while resize_ asks for
        // side^3 = 27 elements; confirm which values survive and whether the
        // (x > threshold) branch is actually exercised.
        auto input = torch::linspace(-3.0, 3.0, 61);
        input.resize_({side, side, side});
        // Reference is built before the call because `input` may be
        // overwritten in place.
        auto reference =
            (input <= threshold) * value + (input > threshold) * input;
        const auto options =
            F::ThresholdFuncOptions(threshold, value).inplace(in_place);
        auto result = F::threshold(input, options);

        ASSERT_EQ(result.ndimension(), 3);
        ASSERT_EQ(result.sizes(), std::vector<int64_t>({side, side, side}));
        ASSERT_TRUE(torch::allclose(result, reference));
        if (in_place) {
          ASSERT_TRUE(torch::allclose(input, reference));
        }
      }
    }
  }
  ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5))
                  .defined());
}
1933*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm on a (2, num_features) input with training(false), unit weight
// and zero bias: the result must equal (input - mean) / sqrt(var + eps).
TEST_F(FunctionalTest, BatchNorm1d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features});
  const auto running_mean = torch::randn(num_features);
  const auto running_var = torch::rand(num_features);

  const auto options = F::BatchNormFuncOptions()
                           .weight(torch::ones({num_features}))
                           .bias(torch::zeros({num_features}))
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output =
      F::batch_norm(input, running_mean, running_var, options);

  const auto stddev = torch::sqrt(running_var + eps);
  const auto expected = (input - running_mean) / stddev;
  ASSERT_TRUE(torch::allclose(output, expected));
}
1957*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm with no explicit options on a (2, 5) input; the reference
// value is computed with eps = 1e-5.
TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
  const double reference_eps = 1e-5;

  const auto input = torch::randn({2, 5});
  const auto running_mean = torch::randn(5);
  const auto running_var = torch::rand(5);

  const auto output = F::batch_norm(input, running_mean, running_var);

  const auto stddev = torch::sqrt(running_var + reference_eps);
  ASSERT_TRUE(torch::allclose(output, (input - running_mean) / stddev));
}
1966*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm on a (2, C, 4, 4) input with training(false), unit weight and
// zero bias: each channel must be normalized by its own mean / variance entry.
TEST_F(FunctionalTest, BatchNorm2d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features, 4, 4});
  const auto running_mean = torch::randn(num_features);
  const auto running_var = torch::rand(num_features);

  const auto options = F::BatchNormFuncOptions()
                           .weight(torch::ones({num_features}))
                           .bias(torch::zeros({num_features}))
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output =
      F::batch_norm(input, running_mean, running_var, options);

  // Reshape the per-channel statistics to (1, C, 1, 1) so they broadcast
  // along dim 1 of the input.
  const auto mean_bc = running_mean.view({1, num_features, 1, 1});
  const auto std_bc =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1});
  const auto expected = (input - mean_bc) / std_bc;
  ASSERT_TRUE(torch::allclose(output, expected));
}
1993*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm with no explicit options on a (2, C, 4, 4) input; the
// reference value is computed with eps = 1e-05.
TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) {
  const int64_t num_features = 5;
  const double eps = 1e-05;

  const auto input = torch::randn({2, num_features, 4, 4});
  const auto running_mean = torch::randn(num_features);
  const auto running_var = torch::rand(num_features);

  const auto output = F::batch_norm(input, running_mean, running_var);

  // Per-channel statistics reshaped to (1, C, 1, 1) to broadcast along dim 1.
  const auto mean_bc = running_mean.view({1, num_features, 1, 1});
  const auto std_bc =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1});
  ASSERT_TRUE(torch::allclose(output, (input - mean_bc) / std_bc));
}
2008*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm on a (2, C, 2, 2, 2) input with training(false), unit weight
// and zero bias: each channel must be normalized by its own mean / variance.
TEST_F(FunctionalTest, BatchNorm3d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;

  const auto input = torch::randn({2, num_features, 2, 2, 2});
  const auto running_mean = torch::randn(num_features);
  const auto running_var = torch::rand(num_features);

  const auto options = F::BatchNormFuncOptions()
                           .weight(torch::ones({num_features}))
                           .bias(torch::zeros({num_features}))
                           .momentum(momentum)
                           .eps(eps)
                           .training(false);
  const auto output =
      F::batch_norm(input, running_mean, running_var, options);

  // Reshape the per-channel statistics to (1, C, 1, 1, 1) so they broadcast
  // along dim 1 of the input.
  const auto mean_bc = running_mean.view({1, num_features, 1, 1, 1});
  const auto std_bc =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1, 1});
  const auto expected = (input - mean_bc) / std_bc;
  ASSERT_TRUE(torch::allclose(output, expected));
}
2035*da0073e9SAndroid Build Coastguard Worker 
// F::batch_norm with no explicit options on a (2, C, 2, 2, 2) input; the
// reference value is computed with eps = 1e-05.
TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
  const int64_t num_features = 5;
  const double eps = 1e-05;

  const auto input = torch::randn({2, num_features, 2, 2, 2});
  const auto running_mean = torch::randn(num_features);
  const auto running_var = torch::rand(num_features);

  const auto output = F::batch_norm(input, running_mean, running_var);

  // Per-channel statistics reshaped to (1, C, 1, 1, 1) to broadcast along
  // dim 1.
  const auto mean_bc = running_mean.view({1, num_features, 1, 1, 1});
  const auto std_bc =
      torch::sqrt(running_var + eps).view({1, num_features, 1, 1, 1});
  ASSERT_TRUE(torch::allclose(output, (input - mean_bc) / std_bc));
}
2050*da0073e9SAndroid Build Coastguard Worker 
// F::instance_norm on an arange-filled (2, C, 4) input with running stats,
// weight and bias all set to arange(C); compared against precomputed values.
TEST_F(FunctionalTest, InstanceNorm1d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;
  const auto channels = static_cast<double>(num_features);

  const auto input = torch::arange(40.).view({2, num_features, 4});
  auto running_mean = torch::arange(channels);
  auto running_var = torch::arange(channels);

  const auto options = F::InstanceNormFuncOptions()
                           .running_mean(running_mean)
                           .running_var(running_var)
                           .weight(torch::arange(channels))
                           .bias(torch::arange(channels))
                           .momentum(momentum)
                           .eps(eps);
  const auto output = F::instance_norm(input, options);

  // Both batch entries share the same expected values, so the (C, 4) table is
  // written once and expanded over the batch dimension.
  const auto per_sample = torch::tensor(
      {{0.0000, 0.0000, 0.0000, 0.0000},
       {-0.3416, 0.5528, 1.4472, 2.3416},
       {-0.6833, 1.1056, 2.8944, 4.6833},
       {-1.0249, 1.6584, 4.3416, 7.0249},
       {-1.3665, 2.2112, 5.7888, 9.3665}});
  const auto expected = per_sample.expand({2, num_features, 4});
  ASSERT_TRUE(torch::allclose(output, expected, /*rtol=*/2e-04));
}
2083*da0073e9SAndroid Build Coastguard Worker 
// F::instance_norm with no explicit options on an arange-filled (2, 5, 4)
// input; every length-4 row is expected to normalize to the same values.
TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
  const auto input = torch::arange(40.).view({2, 5, 4});
  const auto output = F::instance_norm(input);

  // The expected row is identical for all (batch, channel) pairs, so it is
  // written once and expanded to the full output shape.
  const auto row = torch::tensor({-1.3416, -0.4472, 0.4472, 1.3416});
  const auto expected = row.expand({2, 5, 4});
  ASSERT_TRUE(torch::allclose(output, expected, /*rtol=*/2e-04));
}
2100*da0073e9SAndroid Build Coastguard Worker 
// F::instance_norm on an arange-filled (2, C, 2, 2) input with running stats,
// weight and bias all set to arange(C); compared against precomputed values.
TEST_F(FunctionalTest, InstanceNorm2d) {
  const int64_t num_features = 5;
  const double eps = 1e-05;
  const double momentum = 0.1;
  const auto channels = static_cast<double>(num_features);

  const auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  auto running_mean = torch::arange(channels);
  auto running_var = torch::arange(channels);

  const auto options = F::InstanceNormFuncOptions()
                           .running_mean(running_mean)
                           .running_var(running_var)
                           .weight(torch::arange(channels))
                           .bias(torch::arange(channels))
                           .momentum(momentum)
                           .eps(eps);
  const auto output = F::instance_norm(input, options);

  // Both batch entries share the same expected values, so the (C, 2, 2) table
  // is written once and expanded over the batch dimension.
  const auto per_sample = torch::tensor(
      {{{0.0000, 0.0000}, {0.0000, 0.0000}},
       {{-0.3416, 0.5528}, {1.4472, 2.3416}},
       {{-0.6833, 1.1056}, {2.8944, 4.6833}},
       {{-1.0249, 1.6584}, {4.3416, 7.0249}},
       {{-1.3665, 2.2112}, {5.7888, 9.3665}}});
  const auto expected = per_sample.expand({2, num_features, 2, 2});
  ASSERT_TRUE(torch::allclose(output, expected, /*rtol=*/2e-04));
}
2134*da0073e9SAndroid Build Coastguard Worker 
// F::instance_norm with no explicit options on an arange-filled (2, C, 2, 2)
// input; every 2x2 plane is expected to normalize to the same values.
TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
  const int64_t num_features = 5;

  const auto input =
      torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
  const auto output = F::instance_norm(input);

  // The expected plane is identical for all (batch, channel) pairs, so it is
  // written once and expanded to the full output shape.
  const auto plane = torch::tensor({{-1.3416, -0.4472}, {0.4472, 1.3416}});
  const auto expected = plane.expand({2, num_features, 2, 2});
  ASSERT_TRUE(torch::allclose(output, expected, /*rtol=*/2e-04));
}
2154*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,InstanceNorm3d)2155*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, InstanceNorm3d) {
2156*da0073e9SAndroid Build Coastguard Worker   int num_features = 5;
2157*da0073e9SAndroid Build Coastguard Worker   double eps = 1e-05;
2158*da0073e9SAndroid Build Coastguard Worker   double momentum = 0.1;
2159*da0073e9SAndroid Build Coastguard Worker 
2160*da0073e9SAndroid Build Coastguard Worker   auto input = torch::arange(2. * num_features * 2 * 2 * 2)
2161*da0073e9SAndroid Build Coastguard Worker                    .view({2, num_features, 2, 2, 2});
2162*da0073e9SAndroid Build Coastguard Worker   auto mean = torch::arange((double)num_features);
2163*da0073e9SAndroid Build Coastguard Worker   auto variance = torch::arange((double)num_features);
2164*da0073e9SAndroid Build Coastguard Worker   auto weight = torch::arange((double)num_features);
2165*da0073e9SAndroid Build Coastguard Worker   auto bias = torch::arange((double)num_features);
2166*da0073e9SAndroid Build Coastguard Worker   auto output = F::instance_norm(
2167*da0073e9SAndroid Build Coastguard Worker       input,
2168*da0073e9SAndroid Build Coastguard Worker       F::InstanceNormFuncOptions()
2169*da0073e9SAndroid Build Coastguard Worker           .running_mean(mean)
2170*da0073e9SAndroid Build Coastguard Worker           .running_var(variance)
2171*da0073e9SAndroid Build Coastguard Worker           .weight(weight)
2172*da0073e9SAndroid Build Coastguard Worker           .bias(bias)
2173*da0073e9SAndroid Build Coastguard Worker           .momentum(momentum)
2174*da0073e9SAndroid Build Coastguard Worker           .eps(eps));
2175*da0073e9SAndroid Build Coastguard Worker   auto expected = torch::tensor(
2176*da0073e9SAndroid Build Coastguard Worker       {{{{{0.0000, 0.0000}, {0.0000, 0.0000}},
2177*da0073e9SAndroid Build Coastguard Worker          {{0.0000, 0.0000}, {0.0000, 0.0000}}},
2178*da0073e9SAndroid Build Coastguard Worker         {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
2179*da0073e9SAndroid Build Coastguard Worker          {{1.2182, 1.6547}, {2.0911, 2.5275}}},
2180*da0073e9SAndroid Build Coastguard Worker         {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
2181*da0073e9SAndroid Build Coastguard Worker          {{2.4364, 3.3093}, {4.1822, 5.0550}}},
2182*da0073e9SAndroid Build Coastguard Worker         {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
2183*da0073e9SAndroid Build Coastguard Worker          {{3.6547, 4.9640}, {6.2733, 7.5826}}},
2184*da0073e9SAndroid Build Coastguard Worker         {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
2185*da0073e9SAndroid Build Coastguard Worker          {{4.8729, 6.6186}, {8.3644, 10.1101}}}},
2186*da0073e9SAndroid Build Coastguard Worker        {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
2187*da0073e9SAndroid Build Coastguard Worker          {{0.0000, 0.0000}, {0.0000, 0.0000}}},
2188*da0073e9SAndroid Build Coastguard Worker         {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
2189*da0073e9SAndroid Build Coastguard Worker          {{1.2182, 1.6547}, {2.0911, 2.5275}}},
2190*da0073e9SAndroid Build Coastguard Worker         {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
2191*da0073e9SAndroid Build Coastguard Worker          {{2.4364, 3.3093}, {4.1822, 5.0550}}},
2192*da0073e9SAndroid Build Coastguard Worker         {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
2193*da0073e9SAndroid Build Coastguard Worker          {{3.6547, 4.9640}, {6.2733, 7.5826}}},
2194*da0073e9SAndroid Build Coastguard Worker         {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
2195*da0073e9SAndroid Build Coastguard Worker          {{4.8729, 6.6186}, {8.3644, 10.1101}}}}});
2196*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected, 2e-04));
2197*da0073e9SAndroid Build Coastguard Worker }
2198*da0073e9SAndroid Build Coastguard Worker 
// F::instance_norm with no explicit options on an arange-filled
// (2, C, 2, 2, 2) input; every 2x2x2 volume is expected to normalize to the
// same values.
TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
  const int64_t num_features = 5;

  const auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                         .view({2, num_features, 2, 2, 2});
  const auto output = F::instance_norm(input);

  // The expected volume is identical for all (batch, channel) pairs, so it is
  // written once and expanded to the full output shape.
  const auto volume = torch::tensor(
      {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
       {{0.2182, 0.6547}, {1.0911, 1.5275}}});
  const auto expected = volume.expand({2, num_features, 2, 2, 2});
  ASSERT_TRUE(torch::allclose(output, expected, /*rtol=*/2e-04));
}
2228*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,Interpolate)2229*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, Interpolate) {
2230*da0073e9SAndroid Build Coastguard Worker   {
2231*da0073e9SAndroid Build Coastguard Worker     // 1D interpolation
2232*da0073e9SAndroid Build Coastguard Worker     auto input = torch::ones({1, 1, 2});
2233*da0073e9SAndroid Build Coastguard Worker     auto options = F::InterpolateFuncOptions()
2234*da0073e9SAndroid Build Coastguard Worker                        .size(std::vector<int64_t>({4}))
2235*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kNearest);
2236*da0073e9SAndroid Build Coastguard Worker     auto output = F::interpolate(input, options);
2237*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::ones({1, 1, 4});
2238*da0073e9SAndroid Build Coastguard Worker 
2239*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
2240*da0073e9SAndroid Build Coastguard Worker   }
2241*da0073e9SAndroid Build Coastguard Worker   {
2242*da0073e9SAndroid Build Coastguard Worker     // 2D interpolation
2243*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
2244*da0073e9SAndroid Build Coastguard Worker       // test float scale factor up & down sampling
2245*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2246*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2, 2});
2247*da0073e9SAndroid Build Coastguard Worker         auto options =
2248*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions()
2249*da0073e9SAndroid Build Coastguard Worker                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
2250*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBilinear)
2251*da0073e9SAndroid Build Coastguard Worker                 .align_corners(align_corners);
2252*da0073e9SAndroid Build Coastguard Worker         auto output = F::interpolate(input, options);
2253*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
2254*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2255*da0073e9SAndroid Build Coastguard Worker         auto expected = torch::ones({1, 1, expected_size, expected_size});
2256*da0073e9SAndroid Build Coastguard Worker 
2257*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
2258*da0073e9SAndroid Build Coastguard Worker       }
2259*da0073e9SAndroid Build Coastguard Worker     }
2260*da0073e9SAndroid Build Coastguard Worker   }
2261*da0073e9SAndroid Build Coastguard Worker   {
2262*da0073e9SAndroid Build Coastguard Worker     // 3D interpolation
2263*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
2264*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2265*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2, 2, 2});
2266*da0073e9SAndroid Build Coastguard Worker         auto options = F::InterpolateFuncOptions()
2267*da0073e9SAndroid Build Coastguard Worker                            .scale_factor(std::vector<double>(
2268*da0073e9SAndroid Build Coastguard Worker                                {scale_factor, scale_factor, scale_factor}))
2269*da0073e9SAndroid Build Coastguard Worker                            .mode(torch::kTrilinear)
2270*da0073e9SAndroid Build Coastguard Worker                            .align_corners(align_corners);
2271*da0073e9SAndroid Build Coastguard Worker         auto output = F::interpolate(input, options);
2272*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
2273*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2274*da0073e9SAndroid Build Coastguard Worker         auto expected =
2275*da0073e9SAndroid Build Coastguard Worker             torch::ones({1, 1, expected_size, expected_size, expected_size});
2276*da0073e9SAndroid Build Coastguard Worker 
2277*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
2278*da0073e9SAndroid Build Coastguard Worker       }
2279*da0073e9SAndroid Build Coastguard Worker     }
2280*da0073e9SAndroid Build Coastguard Worker   }
2281*da0073e9SAndroid Build Coastguard Worker   {
2282*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2283*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2284*da0073e9SAndroid Build Coastguard Worker             torch::randn({1}),
2285*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions().size(std::vector<int64_t>({1}))),
2286*da0073e9SAndroid Build Coastguard Worker         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 1D) ");
2287*da0073e9SAndroid Build Coastguard Worker   }
2288*da0073e9SAndroid Build Coastguard Worker   {
2289*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 2, 2});
2290*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2291*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2292*da0073e9SAndroid Build Coastguard Worker             input[0],
2293*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions().size(std::vector<int64_t>({4, 4}))),
2294*da0073e9SAndroid Build Coastguard Worker         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) "
2295*da0073e9SAndroid Build Coastguard Worker         "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2296*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2297*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2298*da0073e9SAndroid Build Coastguard Worker             torch::reshape(input, {1, 1, 1, 3, 2, 2}),
2299*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions().size(
2300*da0073e9SAndroid Build Coastguard Worker                 std::vector<int64_t>({1, 1, 1, 3, 4, 4}))),
2301*da0073e9SAndroid Build Coastguard Worker         "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) "
2302*da0073e9SAndroid Build Coastguard Worker         "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2303*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2304*da0073e9SAndroid Build Coastguard Worker         F::interpolate(input, F::InterpolateFuncOptions()),
2305*da0073e9SAndroid Build Coastguard Worker         "either size or scale_factor should be defined");
2306*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2307*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2308*da0073e9SAndroid Build Coastguard Worker             input,
2309*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions()
2310*da0073e9SAndroid Build Coastguard Worker                 .size(std::vector<int64_t>({3, 4, 4}))
2311*da0073e9SAndroid Build Coastguard Worker                 .scale_factor(std::vector<double>({0.5}))),
2312*da0073e9SAndroid Build Coastguard Worker         "only one of size or scale_factor should be defined");
2313*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2314*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2315*da0073e9SAndroid Build Coastguard Worker             input,
2316*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions().scale_factor(
2317*da0073e9SAndroid Build Coastguard Worker                 std::vector<double>({3, 2}))),
2318*da0073e9SAndroid Build Coastguard Worker         "scale_factor shape must match input shape. "
2319*da0073e9SAndroid Build Coastguard Worker         "Input is 1D, scale_factor size is [3, 2]");
2320*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2321*da0073e9SAndroid Build Coastguard Worker         F::interpolate(
2322*da0073e9SAndroid Build Coastguard Worker             input,
2323*da0073e9SAndroid Build Coastguard Worker             F::InterpolateFuncOptions()
2324*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kNearest)
2325*da0073e9SAndroid Build Coastguard Worker                 .align_corners(true)),
2326*da0073e9SAndroid Build Coastguard Worker         "align_corners option can only be set with the "
2327*da0073e9SAndroid Build Coastguard Worker         "interpolating modes: linear | bilinear | bicubic | trilinear");
2328*da0073e9SAndroid Build Coastguard Worker   }
2329*da0073e9SAndroid Build Coastguard Worker   {
2330*da0073e9SAndroid Build Coastguard Worker     auto tensor = torch::rand({2, 3, 32, 32});
2331*da0073e9SAndroid Build Coastguard Worker     std::vector<int64_t> osize = {8, 10};
2332*da0073e9SAndroid Build Coastguard Worker     auto expected =
2333*da0073e9SAndroid Build Coastguard Worker         at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);
2334*da0073e9SAndroid Build Coastguard Worker 
2335*da0073e9SAndroid Build Coastguard Worker     auto options = F::InterpolateFuncOptions()
2336*da0073e9SAndroid Build Coastguard Worker                        .size(osize)
2337*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kNearestExact)
2338*da0073e9SAndroid Build Coastguard Worker                        .align_corners(false);
2339*da0073e9SAndroid Build Coastguard Worker     auto output = F::interpolate(tensor, options);
2340*da0073e9SAndroid Build Coastguard Worker 
2341*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
2342*da0073e9SAndroid Build Coastguard Worker   }
2343*da0073e9SAndroid Build Coastguard Worker   {
2344*da0073e9SAndroid Build Coastguard Worker     auto tensor = torch::rand({2, 3, 32, 32});
2345*da0073e9SAndroid Build Coastguard Worker     std::vector<int64_t> osize = {8, 10};
2346*da0073e9SAndroid Build Coastguard Worker     auto expected = at::native::_upsample_bilinear2d_aa(
2347*da0073e9SAndroid Build Coastguard Worker         tensor, osize, false, torch::nullopt);
2348*da0073e9SAndroid Build Coastguard Worker 
2349*da0073e9SAndroid Build Coastguard Worker     auto options = F::InterpolateFuncOptions()
2350*da0073e9SAndroid Build Coastguard Worker                        .size(osize)
2351*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kBilinear)
2352*da0073e9SAndroid Build Coastguard Worker                        .align_corners(false)
2353*da0073e9SAndroid Build Coastguard Worker                        .antialias(true);
2354*da0073e9SAndroid Build Coastguard Worker     auto output = F::interpolate(tensor, options);
2355*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
2356*da0073e9SAndroid Build Coastguard Worker   }
2357*da0073e9SAndroid Build Coastguard Worker   {
2358*da0073e9SAndroid Build Coastguard Worker     auto tensor = torch::rand({2, 3, 32, 32});
2359*da0073e9SAndroid Build Coastguard Worker     std::vector<int64_t> osize = {8, 10};
2360*da0073e9SAndroid Build Coastguard Worker     auto expected = at::native::_upsample_bicubic2d_aa(
2361*da0073e9SAndroid Build Coastguard Worker         tensor, osize, false, torch::nullopt);
2362*da0073e9SAndroid Build Coastguard Worker 
2363*da0073e9SAndroid Build Coastguard Worker     auto options = F::InterpolateFuncOptions()
2364*da0073e9SAndroid Build Coastguard Worker                        .size(osize)
2365*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kBicubic)
2366*da0073e9SAndroid Build Coastguard Worker                        .align_corners(false)
2367*da0073e9SAndroid Build Coastguard Worker                        .antialias(true);
2368*da0073e9SAndroid Build Coastguard Worker     auto output = F::interpolate(tensor, options);
2369*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
2370*da0073e9SAndroid Build Coastguard Worker   }
2371*da0073e9SAndroid Build Coastguard Worker }
2372*da0073e9SAndroid Build Coastguard Worker 
// Circular padding of a 3-D tensor: pad widths {1, 2} grow the last
// dimension from 3 to 6, wrapping values around from the opposite end.
TEST_F(FunctionalTest, Pad1) {
  const auto src = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
  const auto pad_opts = F::PadFuncOptions({1, 2}).mode(torch::kCircular);
  const auto padded = F::pad(src, pad_opts);

  const std::vector<int64_t> want_shape = {1, 2, 6};
  const auto want = torch::tensor(
      {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble);

  ASSERT_EQ(padded.sizes(), want_shape);
  ASSERT_TRUE(torch::allclose(padded, want, 1e-04));
}
// Circular padding of a 4-D tensor: pad widths {3, 3, 3, 1} turn the
// trailing 3x3 plane into a 7x9 plane.
TEST_F(FunctionalTest, Pad2) {
  const auto src = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
  const auto pad_opts =
      F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular);
  const auto padded = F::pad(src, pad_opts);

  const std::vector<int64_t> want_shape = {1, 1, 7, 9};
  const auto want = torch::tensor(
      {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
         {0., 1., 2., 0., 1., 2., 0., 1., 2.},
         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
         {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}},
      torch::kDouble);

  ASSERT_EQ(padded.sizes(), want_shape);
  ASSERT_TRUE(torch::allclose(padded, want, 1e-04));
}
TEST_F(FunctionalTest,Pad3)2402*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, Pad3) {
2403*da0073e9SAndroid Build Coastguard Worker   {
2404*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
2405*da0073e9SAndroid Build Coastguard Worker     auto output = F::pad(
2406*da0073e9SAndroid Build Coastguard Worker         input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
2407*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
2408*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2409*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2410*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2411*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2412*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2413*da0073e9SAndroid Build Coastguard Worker 
2414*da0073e9SAndroid Build Coastguard Worker            {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2415*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2416*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2417*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2418*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.}},
2419*da0073e9SAndroid Build Coastguard Worker 
2420*da0073e9SAndroid Build Coastguard Worker            {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2421*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2422*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2423*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2424*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2425*da0073e9SAndroid Build Coastguard Worker 
2426*da0073e9SAndroid Build Coastguard Worker            {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2427*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2428*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2429*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2430*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.}},
2431*da0073e9SAndroid Build Coastguard Worker 
2432*da0073e9SAndroid Build Coastguard Worker            {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2433*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2434*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2435*da0073e9SAndroid Build Coastguard Worker             {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2436*da0073e9SAndroid Build Coastguard Worker             {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2437*da0073e9SAndroid Build Coastguard Worker 
2438*da0073e9SAndroid Build Coastguard Worker            {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2439*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2440*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2441*da0073e9SAndroid Build Coastguard Worker             {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2442*da0073e9SAndroid Build Coastguard Worker             {6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}},
2443*da0073e9SAndroid Build Coastguard Worker         torch::kDouble);
2444*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
2445*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected, 1e-04));
2446*da0073e9SAndroid Build Coastguard Worker   }
2447*da0073e9SAndroid Build Coastguard Worker }
// Reflect padding of a 4-D tensor: one element on every side of each 2x2
// plane yields 4x4 planes.
TEST_F(FunctionalTest, Pad4) {
  const auto src = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
  const auto pad_opts = F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect);
  const auto padded = F::pad(src, pad_opts);

  const std::vector<int64_t> want_shape = {2, 2, 4, 4};
  const auto want = torch::tensor(
      {{{{3., 2., 3., 2.},
         {1., 0., 1., 0.},
         {3., 2., 3., 2.},
         {1., 0., 1., 0.}},

        {{7., 6., 7., 6.},
         {5., 4., 5., 4.},
         {7., 6., 7., 6.},
         {5., 4., 5., 4.}}},

       {{{11., 10., 11., 10.},
         {9., 8., 9., 8.},
         {11., 10., 11., 10.},
         {9., 8., 9., 8.}},

        {{15., 14., 15., 14.},
         {13., 12., 13., 12.},
         {15., 14., 15., 14.},
         {13., 12., 13., 12.}}}},
      torch::kDouble);

  ASSERT_EQ(padded.sizes(), want_shape);
  ASSERT_TRUE(torch::allclose(padded, want, 1e-04));
}
TEST_F(FunctionalTest,Pad5)2478*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, Pad5) {
2479*da0073e9SAndroid Build Coastguard Worker   {
2480*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
2481*da0073e9SAndroid Build Coastguard Worker     auto output = F::pad(
2482*da0073e9SAndroid Build Coastguard Worker         input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
2483*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
2484*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 0., 1., 2., 2., 2.},
2485*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 2., 2., 2.},
2486*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 2., 2., 2.},
2487*da0073e9SAndroid Build Coastguard Worker             {3., 3., 4., 5., 5., 5.},
2488*da0073e9SAndroid Build Coastguard Worker             {3., 3., 4., 5., 5., 5.}},
2489*da0073e9SAndroid Build Coastguard Worker 
2490*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 1., 2., 2., 2.},
2491*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 2., 2., 2.},
2492*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 2., 2., 2.},
2493*da0073e9SAndroid Build Coastguard Worker             {3., 3., 4., 5., 5., 5.},
2494*da0073e9SAndroid Build Coastguard Worker             {3., 3., 4., 5., 5., 5.}},
2495*da0073e9SAndroid Build Coastguard Worker 
2496*da0073e9SAndroid Build Coastguard Worker            {{6., 6., 7., 8., 8., 8.},
2497*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2498*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2499*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.},
2500*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.}},
2501*da0073e9SAndroid Build Coastguard Worker 
2502*da0073e9SAndroid Build Coastguard Worker            {{6., 6., 7., 8., 8., 8.},
2503*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2504*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2505*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.},
2506*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.}},
2507*da0073e9SAndroid Build Coastguard Worker 
2508*da0073e9SAndroid Build Coastguard Worker            {{6., 6., 7., 8., 8., 8.},
2509*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2510*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 8., 8., 8.},
2511*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.},
2512*da0073e9SAndroid Build Coastguard Worker             {9., 9., 10., 11., 11., 11.}}}}},
2513*da0073e9SAndroid Build Coastguard Worker         torch::kDouble);
2514*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
2515*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected, 1e-04));
2516*da0073e9SAndroid Build Coastguard Worker   }
2517*da0073e9SAndroid Build Coastguard Worker }
// Reflect padding of a 5-D tensor: pad widths {0, 2, 1, 0, 1, 2} turn the
// trailing 3x2x3 volume into a 6x3x5 volume.
TEST_F(FunctionalTest, Pad6) {
  const auto src = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3});
  const auto pad_opts =
      F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect);
  const auto padded = F::pad(src, pad_opts);

  const std::vector<int64_t> want_shape = {1, 1, 6, 3, 5};
  const auto want = torch::tensor(
      {{{{{9., 10., 11., 10., 9.},
          {6., 7., 8., 7., 6.},
          {9., 10., 11., 10., 9.}},

         {{3., 4., 5., 4., 3.},
          {0., 1., 2., 1., 0.},
          {3., 4., 5., 4., 3.}},

         {{9., 10., 11., 10., 9.},
          {6., 7., 8., 7., 6.},
          {9., 10., 11., 10., 9.}},

         {{15., 16., 17., 16., 15.},
          {12., 13., 14., 13., 12.},
          {15., 16., 17., 16., 15.}},

         {{9., 10., 11., 10., 9.},
          {6., 7., 8., 7., 6.},
          {9., 10., 11., 10., 9.}},

         {{3., 4., 5., 4., 3.},
          {0., 1., 2., 1., 0.},
          {3., 4., 5., 4., 3.}}}}},
      torch::kDouble);

  ASSERT_EQ(padded.sizes(), want_shape);
  ASSERT_TRUE(torch::allclose(padded, want, 1e-04));
}
// Constant padding with an explicit mode and fill value of 0: a single-element
// 4-D tensor padded by one on each side of its last dimension.
TEST_F(FunctionalTest, Pad7) {
  {
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(
        input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Previously `expected` was built but never compared, so only the shape
    // was verified; check the padded values as the sibling Pad tests do.
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
// Padding with default options (no mode or value given): a single-element
// 4-D tensor padded by one on each side of its last dimension is expected to
// be zero-filled.
TEST_F(FunctionalTest, Pad8) {
  {
    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
    auto output = F::pad(input, F::PadFuncOptions({1, 1}));
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
    // Previously `expected` was built but never compared, so only the shape
    // was verified; check the padded values as the sibling Pad tests do.
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
2567*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,CTCLoss)2568*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, CTCLoss) {
2569*da0073e9SAndroid Build Coastguard Worker   { // test CTCLoss typechecks
2570*da0073e9SAndroid Build Coastguard Worker     const auto target_lengths = torch::tensor({30, 25, 20});
2571*da0073e9SAndroid Build Coastguard Worker     const auto input_lengths = torch::tensor({50, 50, 50});
2572*da0073e9SAndroid Build Coastguard Worker     const auto targets =
2573*da0073e9SAndroid Build Coastguard Worker         torch::randint(1, 15, {target_lengths.sum().item<int>()}, torch::kInt);
2574*da0073e9SAndroid Build Coastguard Worker     const auto log_probs =
2575*da0073e9SAndroid Build Coastguard Worker         torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
2576*da0073e9SAndroid Build Coastguard Worker 
2577*da0073e9SAndroid Build Coastguard Worker     const auto _input_lengths = input_lengths.to(torch::kFloat);
2578*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2579*da0073e9SAndroid Build Coastguard Worker         F::ctc_loss(log_probs, targets, _input_lengths, target_lengths),
2580*da0073e9SAndroid Build Coastguard Worker         "input_lengths must be integral");
2581*da0073e9SAndroid Build Coastguard Worker 
2582*da0073e9SAndroid Build Coastguard Worker     const auto target_lengths_ = target_lengths.to(torch::kFloat);
2583*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2584*da0073e9SAndroid Build Coastguard Worker         F::ctc_loss(log_probs, targets, input_lengths, target_lengths_),
2585*da0073e9SAndroid Build Coastguard Worker         "target_lengths must be integral");
2586*da0073e9SAndroid Build Coastguard Worker   }
2587*da0073e9SAndroid Build Coastguard Worker   { // test CTCLoss length checks
2588*da0073e9SAndroid Build Coastguard Worker     const auto target_lengths = torch::tensor({30, 25, 20});
2589*da0073e9SAndroid Build Coastguard Worker     const auto input_lengths = torch::tensor({50, 50, 50});
2590*da0073e9SAndroid Build Coastguard Worker     const auto targets = torch::randint(1, 15, {3, 29}, torch::kInt);
2591*da0073e9SAndroid Build Coastguard Worker     const auto log_probs =
2592*da0073e9SAndroid Build Coastguard Worker         torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
2593*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2594*da0073e9SAndroid Build Coastguard Worker         F::ctc_loss(log_probs, targets, input_lengths, target_lengths),
2595*da0073e9SAndroid Build Coastguard Worker         "Expected tensor to have size at least 30 at dimension 1");
2596*da0073e9SAndroid Build Coastguard Worker   }
2597*da0073e9SAndroid Build Coastguard Worker   { // test CTCLoss empty target
2598*da0073e9SAndroid Build Coastguard Worker     {
2599*da0073e9SAndroid Build Coastguard Worker       const auto target_lengths = torch::tensor({0, 0, 0});
2600*da0073e9SAndroid Build Coastguard Worker       const auto input_lengths = torch::tensor({50, 50, 50});
2601*da0073e9SAndroid Build Coastguard Worker       const auto targets =
2602*da0073e9SAndroid Build Coastguard Worker           torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
2603*da0073e9SAndroid Build Coastguard Worker       const auto log_probs =
2604*da0073e9SAndroid Build Coastguard Worker           torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
2605*da0073e9SAndroid Build Coastguard Worker       const auto loss = F::ctc_loss(
2606*da0073e9SAndroid Build Coastguard Worker           log_probs,
2607*da0073e9SAndroid Build Coastguard Worker           targets,
2608*da0073e9SAndroid Build Coastguard Worker           input_lengths,
2609*da0073e9SAndroid Build Coastguard Worker           target_lengths,
2610*da0073e9SAndroid Build Coastguard Worker           F::CTCLossFuncOptions().reduction(torch::kNone));
2611*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(loss.ge(0).all().item<bool>());
2612*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(torch::allclose(
2613*da0073e9SAndroid Build Coastguard Worker           -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss));
2614*da0073e9SAndroid Build Coastguard Worker     }
2615*da0073e9SAndroid Build Coastguard Worker     {
2616*da0073e9SAndroid Build Coastguard Worker       const auto target_lengths = torch::tensor({0, 9, 0});
2617*da0073e9SAndroid Build Coastguard Worker       const auto input_lengths = torch::tensor({50, 50, 50});
2618*da0073e9SAndroid Build Coastguard Worker       const auto targets = torch::randint(1, 15, {9}, torch::kLong);
2619*da0073e9SAndroid Build Coastguard Worker       const auto log_probs =
2620*da0073e9SAndroid Build Coastguard Worker           torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
2621*da0073e9SAndroid Build Coastguard Worker       const auto loss = F::ctc_loss(
2622*da0073e9SAndroid Build Coastguard Worker           log_probs,
2623*da0073e9SAndroid Build Coastguard Worker           targets,
2624*da0073e9SAndroid Build Coastguard Worker           input_lengths,
2625*da0073e9SAndroid Build Coastguard Worker           target_lengths,
2626*da0073e9SAndroid Build Coastguard Worker           F::CTCLossFuncOptions().reduction(torch::kNone));
2627*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(loss.ge(0).all().item<bool>());
2628*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(torch::allclose(
2629*da0073e9SAndroid Build Coastguard Worker           -log_probs.sum(0)
2630*da0073e9SAndroid Build Coastguard Worker                .index_select(0, torch::tensor({0, 2}, torch::kLong))
2631*da0073e9SAndroid Build Coastguard Worker                .slice(1, 0, 1)
2632*da0073e9SAndroid Build Coastguard Worker                .view({2}),
2633*da0073e9SAndroid Build Coastguard Worker           loss.index_select(0, torch::tensor({0, 2}, torch::kLong))));
2634*da0073e9SAndroid Build Coastguard Worker     }
2635*da0073e9SAndroid Build Coastguard Worker   }
2636*da0073e9SAndroid Build Coastguard Worker }
2637*da0073e9SAndroid Build Coastguard Worker 
// Checks F::poisson_nll_loss against exp(input) - target * input for the
// default options and for each explicit reduction mode.
TEST_F(FunctionalTest, PoissonNLLLoss) {
  const auto input = torch::tensor({0.5, 1.5, 2.5});
  const auto target = torch::tensor({1., 2., 3.});
  const auto per_element = torch::exp(input) - target * input;

  // Default options are expected to match the mean of the per-element loss.
  const auto default_loss = F::poisson_nll_loss(input, target);
  ASSERT_TRUE(torch::allclose(torch::mean(per_element), default_loss));

  const auto none_opts =
      F::PoissonNLLLossFuncOptions().reduction(torch::kNone);
  ASSERT_TRUE(torch::allclose(
      per_element, F::poisson_nll_loss(input, target, none_opts)));

  const auto sum_opts = F::PoissonNLLLossFuncOptions().reduction(torch::kSum);
  ASSERT_TRUE(torch::allclose(
      torch::sum(per_element), F::poisson_nll_loss(input, target, sum_opts)));

  const auto mean_opts =
      F::PoissonNLLLossFuncOptions().reduction(torch::kMean);
  ASSERT_TRUE(torch::allclose(
      torch::mean(per_element),
      F::poisson_nll_loss(input, target, mean_opts)));
}
2663*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,MarginRankingLoss)2664*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, MarginRankingLoss) {
2665*da0073e9SAndroid Build Coastguard Worker   {
2666*da0073e9SAndroid Build Coastguard Worker     const auto input1 = torch::randn(15) * 10;
2667*da0073e9SAndroid Build Coastguard Worker     const auto input2 = torch::randn(15) * 10;
2668*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::randn(15).sign();
2669*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
2670*da0073e9SAndroid Build Coastguard Worker         F::margin_ranking_loss(input1, input2, target),
2671*da0073e9SAndroid Build Coastguard Worker         (-target * (input1 - input2)).clamp(0).mean()));
2672*da0073e9SAndroid Build Coastguard Worker   }
2673*da0073e9SAndroid Build Coastguard Worker   {
2674*da0073e9SAndroid Build Coastguard Worker     const auto input1 = torch::randn(15) * 10;
2675*da0073e9SAndroid Build Coastguard Worker     const auto input2 = torch::randn(15) * 10;
2676*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::randn(15).sign();
2677*da0073e9SAndroid Build Coastguard Worker     const auto margin = 0.5;
2678*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
2679*da0073e9SAndroid Build Coastguard Worker         F::margin_ranking_loss(
2680*da0073e9SAndroid Build Coastguard Worker             input1,
2681*da0073e9SAndroid Build Coastguard Worker             input2,
2682*da0073e9SAndroid Build Coastguard Worker             target,
2683*da0073e9SAndroid Build Coastguard Worker             F::MarginRankingLossFuncOptions().margin(0.5).reduction(
2684*da0073e9SAndroid Build Coastguard Worker                 torch::kSum)),
2685*da0073e9SAndroid Build Coastguard Worker         (-target * (input1 - input2) + margin).clamp(0).sum()));
2686*da0073e9SAndroid Build Coastguard Worker   }
2687*da0073e9SAndroid Build Coastguard Worker   {
2688*da0073e9SAndroid Build Coastguard Worker     const auto input1 = torch::randn(15) * 10;
2689*da0073e9SAndroid Build Coastguard Worker     const auto input2 = torch::randn(15) * 10;
2690*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::randn(15).sign();
2691*da0073e9SAndroid Build Coastguard Worker     const auto margin = 0.5;
2692*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
2693*da0073e9SAndroid Build Coastguard Worker         F::margin_ranking_loss(
2694*da0073e9SAndroid Build Coastguard Worker             input1,
2695*da0073e9SAndroid Build Coastguard Worker             input2,
2696*da0073e9SAndroid Build Coastguard Worker             target,
2697*da0073e9SAndroid Build Coastguard Worker             F::MarginRankingLossFuncOptions().margin(0.5).reduction(
2698*da0073e9SAndroid Build Coastguard Worker                 torch::kMean)),
2699*da0073e9SAndroid Build Coastguard Worker         (-target * (input1 - input2) + margin).clamp(0).mean()));
2700*da0073e9SAndroid Build Coastguard Worker   }
2701*da0073e9SAndroid Build Coastguard Worker }
2702*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv_transpose1d against a precomputed result. The input is
// arange(20) viewed as {2, 2, 5}, the weight is arange(18) viewed as
// {2, 3, 3}, and the reference output below is a 2 x 3 x 7 tensor. The same
// reference is expected both with an explicit stride(1) option and when the
// options argument is omitted.
TEST_F(FunctionalTest, ConvTranspose1d) {
  const auto input = torch::arange(20.).view({2, 2, 5});
  const auto kernel = torch::arange(18.).view({2, 3, 3});

  const auto reference = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});

  // Explicit options.
  const auto stride_one = F::ConvTranspose1dFuncOptions().stride(1);
  const auto with_options = F::conv_transpose1d(input, kernel, stride_one);
  ASSERT_TRUE(torch::allclose(with_options, reference));

  // Options omitted.
  const auto without_options = F::conv_transpose1d(input, kernel);
  ASSERT_TRUE(torch::allclose(without_options, reference));
}
2720*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv_transpose2d with square spatial sizes against a precomputed
// result. The input is arange(50) viewed as {1, 2, 5, 5}, the weight is
// arange(54) viewed as {2, 3, 3, 3}, and the reference output below is a
// 1 x 3 x 7 x 7 tensor. The same reference is expected both with an explicit
// stride(1) option and when the options argument is omitted.
TEST_F(FunctionalTest, ConvTranspose2dEven) {
  const auto input = torch::arange(50.).view({1, 2, 5, 5});
  const auto kernel = torch::arange(54.).view({2, 3, 3, 3});

  const auto reference = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});

  // Explicit options.
  const auto stride_one = F::ConvTranspose2dFuncOptions().stride(1);
  const auto with_options = F::conv_transpose2d(input, kernel, stride_one);
  ASSERT_TRUE(torch::allclose(with_options, reference));

  // Options omitted.
  const auto without_options = F::conv_transpose2d(input, kernel);
  ASSERT_TRUE(torch::allclose(without_options, reference));
}
2753*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv_transpose2d with non-square spatial sizes against a
// precomputed result. The input is arange(40) viewed as {1, 2, 5, 4}, the
// weight is arange(36) viewed as {2, 3, 3, 2}, and the reference output below
// is a 1 x 3 x 7 x 5 tensor. The same reference is expected both with an
// explicit stride(1) option and when the options argument is omitted.
TEST_F(FunctionalTest, ConvTranspose2dUneven) {
  const auto input = torch::arange(40.).view({1, 2, 5, 4});
  const auto kernel = torch::arange(36.).view({2, 3, 3, 2});

  const auto reference = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});

  // Explicit options.
  const auto stride_one = F::ConvTranspose2dFuncOptions().stride(1);
  const auto with_options = F::conv_transpose2d(input, kernel, stride_one);
  ASSERT_TRUE(torch::allclose(with_options, reference));

  // Options omitted.
  const auto without_options = F::conv_transpose2d(input, kernel);
  ASSERT_TRUE(torch::allclose(without_options, reference));
}
2786*da0073e9SAndroid Build Coastguard Worker 
// Checks F::conv_transpose3d against a precomputed result. The input is
// arange(16) viewed as {1, 2, 2, 2, 2}, the weight is arange(32) viewed as
// {2, 2, 2, 2, 2}, and the reference output below is a 1 x 2 x 3 x 3 x 3
// tensor. The same reference is expected both with an explicit stride(1)
// option and when the options argument is omitted.
TEST_F(FunctionalTest, ConvTranspose3d) {
  const auto input = torch::arange(16.).view({1, 2, 2, 2, 2});
  const auto kernel = torch::arange(32.).view({2, 2, 2, 2, 2});

  const auto reference = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});

  // Explicit options.
  const auto stride_one = F::ConvTranspose3dFuncOptions().stride(1);
  const auto with_options = F::conv_transpose3d(input, kernel, stride_one);
  ASSERT_TRUE(torch::allclose(with_options, reference));

  // Options omitted.
  const auto without_options = F::conv_transpose3d(input, kernel);
  ASSERT_TRUE(torch::allclose(without_options, reference));
}
2804*da0073e9SAndroid Build Coastguard Worker 
// Exercises F::alpha_dropout on 5000 random samples. For every combination of
// drop probability {0.2, 0.5, 0.8} and inplace {false, true}, with
// training(false), the output's mean and standard deviation must stay close
// to the input's (tolerance argument 0.1). In the inplace case the tensor
// that was passed in must also match the returned tensor. Finally the
// F::detail::alpha_dropout overload is called directly and checked the same
// way.
// NOTE(review): every call here passes training=false, so dropout is
// presumably a no-op and these checks look trivially satisfied - confirm
// whether training(true) coverage was intended.
TEST_F(FunctionalTest, AlphaDropout) {
  const auto samples = torch::randn(5000);
  const auto reference_mean = samples.mean();
  const auto reference_std = samples.std();

  for (const auto drop_prob : {0.2, 0.5, 0.8}) {
    for (const auto in_place : {false, true}) {
      // Fresh copy per iteration so an inplace call cannot leak into the
      // next one or into `samples`.
      auto working_copy = samples.clone();
      const auto options = F::AlphaDropoutFuncOptions()
                               .p(drop_prob)
                               .training(false)
                               .inplace(in_place);
      const auto result = F::alpha_dropout(working_copy, options);

      ASSERT_TRUE(torch::allclose(reference_mean, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(reference_std, result.std(), 0.1));
      if (in_place) {
        ASSERT_TRUE(torch::allclose(working_copy, result));
      }
    }
  }

  // Direct call into the detail-level function with positional arguments.
  const auto detail_result =
      F::detail::alpha_dropout(samples, 0.5, false, false);
  ASSERT_TRUE(torch::allclose(reference_mean, detail_result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(reference_std, detail_result.std(), 0.1));
}
2828*da0073e9SAndroid Build Coastguard Worker 
// Exercises F::feature_alpha_dropout on 5000 random samples. For every
// combination of drop probability {0.2, 0.5, 0.8} and inplace {false, true},
// with training(false), the output's mean and standard deviation must stay
// close to the input's (tolerance argument 0.1). In the inplace case the
// tensor that was passed in must also match the returned tensor. Finally the
// function is called with its options omitted and checked the same way.
// NOTE(review): the looped calls pass training=false, so dropout is
// presumably a no-op and these checks look trivially satisfied - confirm
// whether training(true) coverage was intended.
TEST_F(FunctionalTest, FeatureAlphaDropout) {
  const auto samples = torch::randn(5000);
  const auto reference_mean = samples.mean();
  const auto reference_std = samples.std();

  for (const auto drop_prob : {0.2, 0.5, 0.8}) {
    for (const auto in_place : {false, true}) {
      // Fresh copy per iteration so an inplace call cannot leak into the
      // next one or into `samples`.
      auto working_copy = samples.clone();
      const auto options = F::FeatureAlphaDropoutFuncOptions()
                               .p(drop_prob)
                               .training(false)
                               .inplace(in_place);
      const auto result = F::feature_alpha_dropout(working_copy, options);

      ASSERT_TRUE(torch::allclose(reference_mean, result.mean(), 0.1));
      ASSERT_TRUE(torch::allclose(reference_std, result.std(), 0.1));
      if (in_place) {
        ASSERT_TRUE(torch::allclose(working_copy, result));
      }
    }
  }

  // Options omitted.
  const auto default_result = F::feature_alpha_dropout(samples);
  ASSERT_TRUE(torch::allclose(reference_mean, default_result.mean(), 0.1));
  ASSERT_TRUE(torch::allclose(reference_std, default_result.std(), 0.1));
}
2852*da0073e9SAndroid Build Coastguard Worker 
// Exercises F::dropout on 5000 random samples. For drop probabilities
// {0.2, 0.5, 0.8} and for a call with the options omitted, the output mean
// must stay close to the input mean (tolerance arguments 0.01 and 0.05) and
// the output standard deviation must not be smaller than the input's. A last
// call checks that a 0-dim tensor input yields a defined tensor.
// The order of the dropout calls is kept fixed because the fixture seeds the
// random generator.
TEST_F(FunctionalTest, Dropout) {
  const auto samples = torch::randn(5000);
  const auto reference_mean = samples.mean();
  const auto reference_std = samples.std();

  for (const auto drop_prob : {0.2, 0.5, 0.8}) {
    const auto dropped =
        F::dropout(samples, F::DropoutFuncOptions().p(drop_prob));
    ASSERT_TRUE(
        torch::allclose(reference_mean, dropped.mean(), 0.01, 0.05));
    ASSERT_TRUE(reference_std.le(dropped.std()).all().item<bool>());
  }

  // Options omitted.
  const auto default_dropped = F::dropout(samples);
  ASSERT_TRUE(
      torch::allclose(reference_mean, default_dropped.mean(), 0.01, 0.05));
  ASSERT_TRUE(reference_std.le(default_dropped.std()).all().item<bool>());

  // Scalar (0-dim) input.
  const auto scalar_dropped = F::dropout(torch::tensor(1.));
  ASSERT_TRUE(scalar_dropped.defined());
}
2868*da0073e9SAndroid Build Coastguard Worker 
// Exercises F::dropout2d on a random {2, 2, 50, 100} input. For drop
// probabilities {0.2, 0.5, 0.8} and for a call with the options omitted, the
// output mean must stay close to the input mean (tolerance arguments 0.01 and
// 0.05). A last call checks that a 3-dimensional {2, 50, 100} input yields a
// defined tensor.
// Only the mean is compared, so no input standard deviation is computed.
TEST_F(FunctionalTest, Dropout2d) {
  const auto input = torch::randn({2, 2, 50, 100});
  const auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    const auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }

  // Options omitted.
  const auto output = F::dropout2d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));

  // Input without the leading dimension.
  ASSERT_TRUE(F::dropout2d(torch::randn({2, 50, 100})).defined());
}
2882*da0073e9SAndroid Build Coastguard Worker 
// Exercises F::dropout3d on a random {2, 2, 50, 10, 10} input. For drop
// probabilities {0.2, 0.5, 0.8} and for a call with the options omitted, the
// output mean must stay close to the input mean (tolerance arguments 0.01 and
// 0.05). A last call checks that a 4-dimensional {2, 50, 10, 10} input yields
// a defined tensor.
// Only the mean is compared, so no input standard deviation is computed.
TEST_F(FunctionalTest, Dropout3d) {
  const auto input = torch::randn({2, 2, 50, 10, 10});
  const auto input_mean = input.mean();

  for (const auto rate : {0.2, 0.5, 0.8}) {
    const auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate));
    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
  }

  // Options omitted.
  const auto output = F::dropout3d(input);
  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));

  // Input without the leading dimension.
  ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined());
}
2896*da0073e9SAndroid Build Coastguard Worker 
2897*da0073e9SAndroid Build Coastguard Worker template <c10::ScalarType S, typename T>
test_isfinite(const at::Device & device)2898*da0073e9SAndroid Build Coastguard Worker void test_isfinite(const at::Device& device) {
2899*da0073e9SAndroid Build Coastguard Worker   const std::vector<T> values = {
2900*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::lowest(),
2901*da0073e9SAndroid Build Coastguard Worker       0,
2902*da0073e9SAndroid Build Coastguard Worker       1,
2903*da0073e9SAndroid Build Coastguard Worker       42,
2904*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::min(),
2905*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::max()};
2906*da0073e9SAndroid Build Coastguard Worker   for (const auto value : values) {
2907*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::full(
2908*da0073e9SAndroid Build Coastguard Worker         {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2909*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::isfinite(x).all().template item<bool>());
2910*da0073e9SAndroid Build Coastguard Worker   }
2911*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_infinity) {
2912*da0073e9SAndroid Build Coastguard Worker     const auto inf = std::numeric_limits<T>::infinity();
2913*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
2914*da0073e9SAndroid Build Coastguard Worker         {-inf,
2915*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::lowest(),
2916*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(0),
2917*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(1),
2918*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(42),
2919*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::min(),
2920*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::max(),
2921*da0073e9SAndroid Build Coastguard Worker          inf},
2922*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
2923*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
2924*da0073e9SAndroid Build Coastguard Worker         // torch::allclose does not support comparing torch::kBool
2925*da0073e9SAndroid Build Coastguard Worker         torch::isfinite(x).toType(torch::kInt),
2926*da0073e9SAndroid Build Coastguard Worker         torch::tensor(
2927*da0073e9SAndroid Build Coastguard Worker             {false, true, true, true, true, true, true, false},
2928*da0073e9SAndroid Build Coastguard Worker             torch::TensorOptions().device(device))
2929*da0073e9SAndroid Build Coastguard Worker             .toType(torch::kInt)));
2930*da0073e9SAndroid Build Coastguard Worker   }
2931*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_quiet_NaN) {
2932*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
2933*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::quiet_NaN()},
2934*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
2935*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2936*da0073e9SAndroid Build Coastguard Worker   }
2937*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_signaling_NaN) {
2938*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
2939*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::signaling_NaN()},
2940*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
2941*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2942*da0073e9SAndroid Build Coastguard Worker   }
2943*da0073e9SAndroid Build Coastguard Worker }
2944*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared test_isfinite checks on the "cpu" device for each
// dtype / C++ type pair listed below.
TEST_F(FunctionalTest, isfinite) {
  const at::Device cpu("cpu");

  // Integral dtypes.
  test_isfinite<torch::kUInt8, uint8_t>(cpu);
  test_isfinite<torch::kInt8, int8_t>(cpu);
  test_isfinite<torch::kInt16, int16_t>(cpu);
  test_isfinite<torch::kInt32, int32_t>(cpu);
  test_isfinite<torch::kInt64, int64_t>(cpu);

  // Floating-point dtypes.
  test_isfinite<torch::kFloat32, float>(cpu);
  test_isfinite<torch::kFloat64, double>(cpu);
}
2955*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared test_isfinite checks on the "cuda" device for each
// dtype / C++ type pair listed below; unlike the CPU variant this also
// covers kFloat16 via c10::Half.
TEST_F(FunctionalTest, isfinite_CUDA) {
  const at::Device gpu("cuda");

  // Integral dtypes.
  test_isfinite<torch::kUInt8, uint8_t>(gpu);
  test_isfinite<torch::kInt8, int8_t>(gpu);
  test_isfinite<torch::kInt16, int16_t>(gpu);
  test_isfinite<torch::kInt32, int32_t>(gpu);
  test_isfinite<torch::kInt64, int64_t>(gpu);

  // Floating-point dtypes.
  test_isfinite<torch::kFloat32, float>(gpu);
  test_isfinite<torch::kFloat64, double>(gpu);
  test_isfinite<torch::kFloat16, c10::Half>(gpu);
}
2967*da0073e9SAndroid Build Coastguard Worker 
2968*da0073e9SAndroid Build Coastguard Worker template <c10::ScalarType S, typename T>
test_isinf(const at::Device & device)2969*da0073e9SAndroid Build Coastguard Worker void test_isinf(const at::Device& device) {
2970*da0073e9SAndroid Build Coastguard Worker   const std::vector<T> values = {
2971*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::lowest(),
2972*da0073e9SAndroid Build Coastguard Worker       0,
2973*da0073e9SAndroid Build Coastguard Worker       1,
2974*da0073e9SAndroid Build Coastguard Worker       42,
2975*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::min(),
2976*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::max()};
2977*da0073e9SAndroid Build Coastguard Worker   for (const auto value : values) {
2978*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::full(
2979*da0073e9SAndroid Build Coastguard Worker         {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2980*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2981*da0073e9SAndroid Build Coastguard Worker   }
2982*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_infinity) {
2983*da0073e9SAndroid Build Coastguard Worker     const auto inf = std::numeric_limits<T>::infinity();
2984*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
2985*da0073e9SAndroid Build Coastguard Worker         {-inf,
2986*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::lowest(),
2987*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(0),
2988*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(1),
2989*da0073e9SAndroid Build Coastguard Worker          static_cast<T>(42),
2990*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::min(),
2991*da0073e9SAndroid Build Coastguard Worker          std::numeric_limits<T>::max(),
2992*da0073e9SAndroid Build Coastguard Worker          inf},
2993*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
2994*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
2995*da0073e9SAndroid Build Coastguard Worker         // torch::allclose does not support comparing torch::kBool
2996*da0073e9SAndroid Build Coastguard Worker         torch::isinf(x).toType(torch::kInt),
2997*da0073e9SAndroid Build Coastguard Worker         torch::tensor(
2998*da0073e9SAndroid Build Coastguard Worker             {true, false, false, false, false, false, false, true},
2999*da0073e9SAndroid Build Coastguard Worker             torch::TensorOptions().device(device))
3000*da0073e9SAndroid Build Coastguard Worker             .toType(torch::kInt)));
3001*da0073e9SAndroid Build Coastguard Worker   }
3002*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_quiet_NaN) {
3003*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
3004*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::quiet_NaN()},
3005*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3006*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3007*da0073e9SAndroid Build Coastguard Worker   }
3008*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_signaling_NaN) {
3009*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
3010*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::signaling_NaN()},
3011*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3012*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
3013*da0073e9SAndroid Build Coastguard Worker   }
3014*da0073e9SAndroid Build Coastguard Worker }
3015*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared test_isinf checks on the "cpu" device for each
// dtype / C++ type pair listed below.
TEST_F(FunctionalTest, isinf) {
  const at::Device cpu("cpu");

  // Integral dtypes.
  test_isinf<torch::kUInt8, uint8_t>(cpu);
  test_isinf<torch::kInt8, int8_t>(cpu);
  test_isinf<torch::kInt16, int16_t>(cpu);
  test_isinf<torch::kInt32, int32_t>(cpu);
  test_isinf<torch::kInt64, int64_t>(cpu);

  // Floating-point dtypes.
  test_isinf<torch::kFloat32, float>(cpu);
  test_isinf<torch::kFloat64, double>(cpu);
}
3026*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared test_isinf checks on the "cuda" device for each
// dtype / C++ type pair listed below; unlike the CPU variant this also
// covers kFloat16 via c10::Half.
TEST_F(FunctionalTest, isinf_CUDA) {
  const at::Device gpu("cuda");

  // Integral dtypes.
  test_isinf<torch::kUInt8, uint8_t>(gpu);
  test_isinf<torch::kInt8, int8_t>(gpu);
  test_isinf<torch::kInt16, int16_t>(gpu);
  test_isinf<torch::kInt32, int32_t>(gpu);
  test_isinf<torch::kInt64, int64_t>(gpu);

  // Floating-point dtypes.
  test_isinf<torch::kFloat32, float>(gpu);
  test_isinf<torch::kFloat64, double>(gpu);
  test_isinf<torch::kFloat16, c10::Half>(gpu);
}
3038*da0073e9SAndroid Build Coastguard Worker 
3039*da0073e9SAndroid Build Coastguard Worker template <c10::ScalarType S, typename T>
test_allclose(const at::Device & device)3040*da0073e9SAndroid Build Coastguard Worker void test_allclose(const at::Device& device) {
3041*da0073e9SAndroid Build Coastguard Worker   const std::vector<T> values = {
3042*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::lowest(),
3043*da0073e9SAndroid Build Coastguard Worker       0,
3044*da0073e9SAndroid Build Coastguard Worker       1,
3045*da0073e9SAndroid Build Coastguard Worker       42,
3046*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::min(),
3047*da0073e9SAndroid Build Coastguard Worker       std::numeric_limits<T>::max()};
3048*da0073e9SAndroid Build Coastguard Worker   for (const auto value : values) {
3049*da0073e9SAndroid Build Coastguard Worker     const auto x =
3050*da0073e9SAndroid Build Coastguard Worker         torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3051*da0073e9SAndroid Build Coastguard Worker     const auto y =
3052*da0073e9SAndroid Build Coastguard Worker         torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
3053*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, x));
3054*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, y));
3055*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, x));
3056*da0073e9SAndroid Build Coastguard Worker     ASSERT_FALSE(torch::allclose(1.1 * x + 0.1, 1.0 * x));
3057*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(0.99 * x + 0.1, 1.0 * x, 1.1, 0.1));
3058*da0073e9SAndroid Build Coastguard Worker   }
3059*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_infinity) {
3060*da0073e9SAndroid Build Coastguard Worker     const auto inf = std::numeric_limits<T>::infinity();
3061*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
3062*da0073e9SAndroid Build Coastguard Worker         {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3063*da0073e9SAndroid Build Coastguard Worker     const auto y = torch::tensor(
3064*da0073e9SAndroid Build Coastguard Worker         {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
3065*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, x));
3066*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, y));
3067*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, x));
3068*da0073e9SAndroid Build Coastguard Worker   }
3069*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_quiet_NaN) {
3070*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
3071*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::quiet_NaN()},
3072*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3073*da0073e9SAndroid Build Coastguard Worker     const auto y = torch::tensor(
3074*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::quiet_NaN()},
3075*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3076*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3077*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3078*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3079*da0073e9SAndroid Build Coastguard Worker   }
3080*da0073e9SAndroid Build Coastguard Worker   if (std::numeric_limits<T>::has_signaling_NaN) {
3081*da0073e9SAndroid Build Coastguard Worker     const auto x = torch::tensor(
3082*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::signaling_NaN()},
3083*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3084*da0073e9SAndroid Build Coastguard Worker     const auto y = torch::tensor(
3085*da0073e9SAndroid Build Coastguard Worker         {std::numeric_limits<T>::signaling_NaN()},
3086*da0073e9SAndroid Build Coastguard Worker         torch::TensorOptions().dtype(S).device(device));
3087*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
3088*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
3089*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
3090*da0073e9SAndroid Build Coastguard Worker   }
3091*da0073e9SAndroid Build Coastguard Worker }
3092*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared allclose checks on the CPU for each integral and
// floating-point dtype listed below.
TEST_F(FunctionalTest, AllClose) {
  const auto cpu = at::Device("cpu");

  // Integral dtypes.
  test_allclose<torch::kUInt8, uint8_t>(cpu);
  test_allclose<torch::kInt8, int8_t>(cpu);
  test_allclose<torch::kInt16, int16_t>(cpu);
  test_allclose<torch::kInt32, int32_t>(cpu);
  test_allclose<torch::kInt64, int64_t>(cpu);

  // Floating-point dtypes.
  test_allclose<torch::kFloat32, float>(cpu);
  test_allclose<torch::kFloat64, double>(cpu);
}
3103*da0073e9SAndroid Build Coastguard Worker 
// Runs the shared allclose checks on a CUDA device. Covers the same dtypes as
// the CPU variant and additionally half precision.
TEST_F(FunctionalTest, AllClose_CUDA) {
  const auto gpu = at::Device("cuda");

  // Integral dtypes.
  test_allclose<torch::kUInt8, uint8_t>(gpu);
  test_allclose<torch::kInt8, int8_t>(gpu);
  test_allclose<torch::kInt16, int16_t>(gpu);
  test_allclose<torch::kInt32, int32_t>(gpu);
  test_allclose<torch::kInt64, int64_t>(gpu);

  // Floating-point dtypes, including half precision.
  test_allclose<torch::kFloat32, float>(gpu);
  test_allclose<torch::kFloat64, double>(gpu);
  test_allclose<torch::kFloat16, c10::Half>(gpu);
}
3115*da0073e9SAndroid Build Coastguard Worker 
TEST_F(FunctionalTest,BCEWithLogitsLoss)3116*da0073e9SAndroid Build Coastguard Worker TEST_F(FunctionalTest, BCEWithLogitsLoss) {
3117*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits raises if target and input are different size
3118*da0073e9SAndroid Build Coastguard Worker     {
3119*da0073e9SAndroid Build Coastguard Worker       const auto target = torch::rand(5);
3120*da0073e9SAndroid Build Coastguard Worker       const auto input = torch::rand({5, 1});
3121*da0073e9SAndroid Build Coastguard Worker       ASSERT_THROWS_WITH(
3122*da0073e9SAndroid Build Coastguard Worker           F::binary_cross_entropy_with_logits(input, target),
3123*da0073e9SAndroid Build Coastguard Worker           "must be the same as input size");
3124*da0073e9SAndroid Build Coastguard Worker     }
3125*da0073e9SAndroid Build Coastguard Worker     {
3126*da0073e9SAndroid Build Coastguard Worker       const auto target = torch::rand({5, 1});
3127*da0073e9SAndroid Build Coastguard Worker       const auto input = torch::rand(5);
3128*da0073e9SAndroid Build Coastguard Worker       ASSERT_THROWS_WITH(
3129*da0073e9SAndroid Build Coastguard Worker           F::binary_cross_entropy_with_logits(input, target),
3130*da0073e9SAndroid Build Coastguard Worker           "must be the same as input size");
3131*da0073e9SAndroid Build Coastguard Worker     }
3132*da0073e9SAndroid Build Coastguard Worker   }
3133*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits gives same result as sigmoid and bce loss
3134*da0073e9SAndroid Build Coastguard Worker     auto sigmoid = Sigmoid();
3135*da0073e9SAndroid Build Coastguard Worker 
3136*da0073e9SAndroid Build Coastguard Worker     auto target = torch::rand({64, 4});
3137*da0073e9SAndroid Build Coastguard Worker     auto output = torch::rand({64, 4}) - 0.5;
3138*da0073e9SAndroid Build Coastguard Worker 
3139*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3140*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(output, target),
3141*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy(sigmoid(output), target)));
3142*da0073e9SAndroid Build Coastguard Worker 
3143*da0073e9SAndroid Build Coastguard Worker     auto weight = torch::rand(4);
3144*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3145*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(
3146*da0073e9SAndroid Build Coastguard Worker             output,
3147*da0073e9SAndroid Build Coastguard Worker             target,
3148*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3149*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy(
3150*da0073e9SAndroid Build Coastguard Worker             sigmoid(output),
3151*da0073e9SAndroid Build Coastguard Worker             target,
3152*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyFuncOptions().weight(weight))));
3153*da0073e9SAndroid Build Coastguard Worker 
3154*da0073e9SAndroid Build Coastguard Worker     target = torch::zeros({4, 1}, torch::kFloat);
3155*da0073e9SAndroid Build Coastguard Worker     output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3156*da0073e9SAndroid Build Coastguard Worker 
3157*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3158*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(output, target),
3159*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy(sigmoid(output), target)));
3160*da0073e9SAndroid Build Coastguard Worker 
3161*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3162*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(
3163*da0073e9SAndroid Build Coastguard Worker             output,
3164*da0073e9SAndroid Build Coastguard Worker             target,
3165*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(
3166*da0073e9SAndroid Build Coastguard Worker                 torch::kNone)),
3167*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy(
3168*da0073e9SAndroid Build Coastguard Worker             sigmoid(output),
3169*da0073e9SAndroid Build Coastguard Worker             target,
3170*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))));
3171*da0073e9SAndroid Build Coastguard Worker 
3172*da0073e9SAndroid Build Coastguard Worker     weight = torch::rand({1}, torch::kFloat);
3173*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3174*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(
3175*da0073e9SAndroid Build Coastguard Worker             output,
3176*da0073e9SAndroid Build Coastguard Worker             target,
3177*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3178*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy(
3179*da0073e9SAndroid Build Coastguard Worker             sigmoid(output),
3180*da0073e9SAndroid Build Coastguard Worker             target,
3181*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyFuncOptions().weight(weight))));
3182*da0073e9SAndroid Build Coastguard Worker   }
3183*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits has correct grad at zero
3184*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3185*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::zeros({3, 1});
3186*da0073e9SAndroid Build Coastguard Worker     F::binary_cross_entropy_with_logits(
3187*da0073e9SAndroid Build Coastguard Worker         output,
3188*da0073e9SAndroid Build Coastguard Worker         target,
3189*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum))
3190*da0073e9SAndroid Build Coastguard Worker         .backward();
3191*da0073e9SAndroid Build Coastguard Worker     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3192*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3193*da0073e9SAndroid Build Coastguard Worker   }
3194*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits broadcasts weights
3195*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({16, 4});
3196*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({16, 4}) - 0.5;
3197*da0073e9SAndroid Build Coastguard Worker 
3198*da0073e9SAndroid Build Coastguard Worker     auto weight = torch::rand(4);
3199*da0073e9SAndroid Build Coastguard Worker     auto out1 = F::binary_cross_entropy_with_logits(
3200*da0073e9SAndroid Build Coastguard Worker         output,
3201*da0073e9SAndroid Build Coastguard Worker         target,
3202*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3203*da0073e9SAndroid Build Coastguard Worker 
3204*da0073e9SAndroid Build Coastguard Worker     weight = weight.expand({16, 4}).contiguous();
3205*da0073e9SAndroid Build Coastguard Worker     auto out2 = F::binary_cross_entropy_with_logits(
3206*da0073e9SAndroid Build Coastguard Worker         output,
3207*da0073e9SAndroid Build Coastguard Worker         target,
3208*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3209*da0073e9SAndroid Build Coastguard Worker 
3210*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3211*da0073e9SAndroid Build Coastguard Worker 
3212*da0073e9SAndroid Build Coastguard Worker     weight = torch::rand({16, 1});
3213*da0073e9SAndroid Build Coastguard Worker     out1 = F::binary_cross_entropy_with_logits(
3214*da0073e9SAndroid Build Coastguard Worker         output,
3215*da0073e9SAndroid Build Coastguard Worker         target,
3216*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3217*da0073e9SAndroid Build Coastguard Worker 
3218*da0073e9SAndroid Build Coastguard Worker     weight = weight.expand({16, 4}).contiguous();
3219*da0073e9SAndroid Build Coastguard Worker     out2 = F::binary_cross_entropy_with_logits(
3220*da0073e9SAndroid Build Coastguard Worker         output,
3221*da0073e9SAndroid Build Coastguard Worker         target,
3222*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3223*da0073e9SAndroid Build Coastguard Worker 
3224*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3225*da0073e9SAndroid Build Coastguard Worker   }
3226*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits ones in pos weights are the same as none
3227*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({64, 4});
3228*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({64, 4}) - 0.5;
3229*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::ones({64, 4});
3230*da0073e9SAndroid Build Coastguard Worker 
3231*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3232*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(output, target),
3233*da0073e9SAndroid Build Coastguard Worker         F::binary_cross_entropy_with_logits(
3234*da0073e9SAndroid Build Coastguard Worker             output,
3235*da0073e9SAndroid Build Coastguard Worker             target,
3236*da0073e9SAndroid Build Coastguard Worker             F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(
3237*da0073e9SAndroid Build Coastguard Worker                 pos_weight))));
3238*da0073e9SAndroid Build Coastguard Worker   }
3239*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits broadcasts pos weights
3240*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({64, 4});
3241*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({64, 4}) - 0.5;
3242*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::rand(4);
3243*da0073e9SAndroid Build Coastguard Worker     const auto out1 = F::binary_cross_entropy_with_logits(
3244*da0073e9SAndroid Build Coastguard Worker         output,
3245*da0073e9SAndroid Build Coastguard Worker         target,
3246*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3247*da0073e9SAndroid Build Coastguard Worker 
3248*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight1 = pos_weight.expand({1, 4});
3249*da0073e9SAndroid Build Coastguard Worker     const auto out2 = F::binary_cross_entropy_with_logits(
3250*da0073e9SAndroid Build Coastguard Worker         output,
3251*da0073e9SAndroid Build Coastguard Worker         target,
3252*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3253*da0073e9SAndroid Build Coastguard Worker 
3254*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight2 = pos_weight.expand({64, 4});
3255*da0073e9SAndroid Build Coastguard Worker     const auto out3 = F::binary_cross_entropy_with_logits(
3256*da0073e9SAndroid Build Coastguard Worker         output,
3257*da0073e9SAndroid Build Coastguard Worker         target,
3258*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3259*da0073e9SAndroid Build Coastguard Worker 
3260*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3261*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out3));
3262*da0073e9SAndroid Build Coastguard Worker   }
3263*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits with pos weight has correct grad at zero
3264*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3265*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::zeros({3, 1});
3266*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::ones({3, 1});
3267*da0073e9SAndroid Build Coastguard Worker     F::binary_cross_entropy_with_logits(
3268*da0073e9SAndroid Build Coastguard Worker         output,
3269*da0073e9SAndroid Build Coastguard Worker         target,
3270*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions()
3271*da0073e9SAndroid Build Coastguard Worker             .pos_weight(pos_weight)
3272*da0073e9SAndroid Build Coastguard Worker             .reduction(torch::kSum))
3273*da0073e9SAndroid Build Coastguard Worker         .backward();
3274*da0073e9SAndroid Build Coastguard Worker     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3275*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3276*da0073e9SAndroid Build Coastguard Worker     const auto grad = output.grad();
3277*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(grad, expected_grad));
3278*da0073e9SAndroid Build Coastguard Worker   }
3279*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits stability
3280*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::tensor({0., -120.});
3281*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::tensor({0., 1.});
3282*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::tensor({1., 1.});
3283*da0073e9SAndroid Build Coastguard Worker 
3284*da0073e9SAndroid Build Coastguard Worker     const auto out1 = F::binary_cross_entropy_with_logits(output, target);
3285*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3286*da0073e9SAndroid Build Coastguard Worker 
3287*da0073e9SAndroid Build Coastguard Worker     const auto out2 = F::binary_cross_entropy_with_logits(
3288*da0073e9SAndroid Build Coastguard Worker         output,
3289*da0073e9SAndroid Build Coastguard Worker         target,
3290*da0073e9SAndroid Build Coastguard Worker         F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3291*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3292*da0073e9SAndroid Build Coastguard Worker   }
3293*da0073e9SAndroid Build Coastguard Worker }
3294