xref: /aosp_15_r20/external/pytorch/test/cpp/api/modules.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <torch/expanding_array.h>
9*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/functional/activation.h>
10*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/options/activation.h>
11*da0073e9SAndroid Build Coastguard Worker #include <limits>
12*da0073e9SAndroid Build Coastguard Worker #include <random>
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
15*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker class TestModel : public torch::nn::Module {
18*da0073e9SAndroid Build Coastguard Worker  public:
TestModel()19*da0073e9SAndroid Build Coastguard Worker   TestModel()
20*da0073e9SAndroid Build Coastguard Worker       : l1(register_module("l1", Linear(10, 3))),
21*da0073e9SAndroid Build Coastguard Worker         l2(register_module("l2", Linear(3, 5))),
22*da0073e9SAndroid Build Coastguard Worker         l3(register_module("l3", Linear(5, 100))) {}
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker   Linear l1, l2, l3;
25*da0073e9SAndroid Build Coastguard Worker };
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker class NestedModel : public torch::nn::Module {
28*da0073e9SAndroid Build Coastguard Worker  public:
NestedModel()29*da0073e9SAndroid Build Coastguard Worker   NestedModel()
30*da0073e9SAndroid Build Coastguard Worker       : param_(register_parameter("param", torch::empty({3, 2, 21}))),
31*da0073e9SAndroid Build Coastguard Worker         l1(register_module("l1", Linear(5, 20))),
32*da0073e9SAndroid Build Coastguard Worker         t(register_module("test", std::make_shared<TestModel>())) {}
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker   torch::Tensor param_;
35*da0073e9SAndroid Build Coastguard Worker   Linear l1;
36*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<TestModel> t;
37*da0073e9SAndroid Build Coastguard Worker };
38*da0073e9SAndroid Build Coastguard Worker 
39*da0073e9SAndroid Build Coastguard Worker struct ModulesTest : torch::test::SeedingFixture {};
40*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv1d) {
  // Bias-free Conv1d (3 -> 2 channels, kernel 3, stride 1). Its weights are
  // replaced by 0..17 so the forward result can be compared exactly.
  Conv1d conv(Conv1dOptions(3, 2, 3).stride(1).bias(false));
  auto kernel_values = torch::arange(18, torch::dtype(torch::kFloat));
  conv->weight.set_data(kernel_values.reshape({2, 3, 3}));

  // Batch of 2, 3 channels, length 5 -> output length 5 - 3 + 1 = 3.
  const auto input_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::arange(30, input_opts).reshape({2, 3, 5});
  auto output = conv(input);

  const auto reference = torch::tensor(
      {{{312., 348., 384.}, {798., 915., 1032.}},

       {{852., 888., 924.}, {2553., 2670., 2787.}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(output, reference));

  // Backprop a scalar and check the weight received a full-sized gradient.
  torch::Tensor total = output.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(conv->weight.grad().numel(), 3 * 2 * 3);
}
60*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv1dSameStrided) {
  // 'same' padding must be accepted with unit stride and rejected once the
  // stride is raised to 2.
  Conv1dOptions options(3, 2, 3);
  options.stride(1);
  options.padding(torch::kSame);
  Conv1d model_valid(options);

  ASSERT_THROWS_WITH(
      [&] { Conv1d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
}
69*da0073e9SAndroid Build Coastguard Worker 
// NOTE(review): "Ivalid" is a typo for "Invalid". It is left unchanged because
// renaming would change the registered gtest name.
TEST_F(ModulesTest, Conv1dIvalidArg) {
  // A negative group count must be rejected when the module is constructed.
  Conv1dOptions options(3, 2, 3);
  options.groups(-1);
  ASSERT_THROWS_WITH(
      Conv1d(options), "in_channels, groups and out_channels must");
}
75*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv2dEven) {
  // Square 3x3 kernel, 3 -> 2 channels, stride 1, no bias; weights are 0..53.
  Conv2d conv(Conv2dOptions(3, 2, 3).stride(1).bias(false));
  auto kernel_values = torch::arange(54, torch::dtype(torch::kFloat));
  conv->weight.set_data(kernel_values.reshape({2, 3, 3, 3}));

  // One 3x5x5 image of values 0..74 -> 2 output maps of 3x3.
  const auto input_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto image = torch::arange(75, input_opts).reshape({1, 3, 5, 5});
  auto response = conv(image);

  const auto reference = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},

        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(response, reference));

  torch::Tensor total = response.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(conv->weight.grad().numel(), 3 * 2 * 3 * 3);
}
99*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv2dUneven) {
  // Non-square 3x2 kernel, 3 -> 2 channels, stride 1, no bias; weights 0..35.
  Conv2d conv(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
  auto kernel_values = torch::arange(36, torch::dtype(torch::kFloat));
  conv->weight.set_data(kernel_values.reshape({2, 3, 3, 2}));

  // One 3x5x4 image -> outputs of (5 - 3 + 1) x (4 - 2 + 1) = 3x3.
  const auto input_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto image = torch::arange(60, input_opts).reshape({1, 3, 5, 4});
  auto response = conv(image);

  const auto reference = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},

        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(response, reference));

  torch::Tensor total = response.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(conv->weight.grad().numel(), 3 * 2 * 3 * 2);
}
121*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv2dSameStrided) {
  // Non-square 3x4 kernel with 'same' padding. It is accepted at stride 1 and
  // rejected both for a uniform stride of 2 and for a stride greater than 1
  // in only one dimension.
  Conv2dOptions options(3, 2, {3, 4});
  options.stride(1);
  options.padding(torch::kSame);
  Conv2d model_valid(options);

  ASSERT_THROWS_WITH(
      [&] { Conv2d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] { Conv2d model_invalid(options.stride({1, 2})); }(),
      "padding='same' is not supported for strided convolutions");
}
135*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,Conv3d)136*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Conv3d) {
137*da0073e9SAndroid Build Coastguard Worker   Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
138*da0073e9SAndroid Build Coastguard Worker   model->weight.set_data(
139*da0073e9SAndroid Build Coastguard Worker       torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3}));
140*da0073e9SAndroid Build Coastguard Worker   auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
141*da0073e9SAndroid Build Coastguard Worker                .reshape({1, 3, 5, 5, 5});
142*da0073e9SAndroid Build Coastguard Worker   auto y = model(x);
143*da0073e9SAndroid Build Coastguard Worker   auto expected = torch::tensor(
144*da0073e9SAndroid Build Coastguard Worker       {{{{{700704., 703944., 707184.},
145*da0073e9SAndroid Build Coastguard Worker           {716904., 720144., 723384.},
146*da0073e9SAndroid Build Coastguard Worker           {733104., 736344., 739584.}},
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker          {{781704., 784944., 788184.},
149*da0073e9SAndroid Build Coastguard Worker           {797904., 801144., 804384.},
150*da0073e9SAndroid Build Coastguard Worker           {814104., 817344., 820584.}},
151*da0073e9SAndroid Build Coastguard Worker 
152*da0073e9SAndroid Build Coastguard Worker          {{862704., 865944., 869184.},
153*da0073e9SAndroid Build Coastguard Worker           {878904., 882144., 885384.},
154*da0073e9SAndroid Build Coastguard Worker           {895104., 898344., 901584.}}},
155*da0073e9SAndroid Build Coastguard Worker 
156*da0073e9SAndroid Build Coastguard Worker         {{{1724220., 1734021., 1743822.},
157*da0073e9SAndroid Build Coastguard Worker           {1773225., 1783026., 1792827.},
158*da0073e9SAndroid Build Coastguard Worker           {1822230., 1832031., 1841832.}},
159*da0073e9SAndroid Build Coastguard Worker 
160*da0073e9SAndroid Build Coastguard Worker          {{1969245., 1979046., 1988847.},
161*da0073e9SAndroid Build Coastguard Worker           {2018250., 2028051., 2037852.},
162*da0073e9SAndroid Build Coastguard Worker           {2067255., 2077056., 2086857.}},
163*da0073e9SAndroid Build Coastguard Worker 
164*da0073e9SAndroid Build Coastguard Worker          {{2214270., 2224071., 2233872.},
165*da0073e9SAndroid Build Coastguard Worker           {2263275., 2273076., 2282877.},
166*da0073e9SAndroid Build Coastguard Worker           {2312280., 2322081., 2331882.}}}}},
167*da0073e9SAndroid Build Coastguard Worker       torch::kFloat);
168*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(y, expected));
169*da0073e9SAndroid Build Coastguard Worker 
170*da0073e9SAndroid Build Coastguard Worker   torch::Tensor s = y.sum();
171*da0073e9SAndroid Build Coastguard Worker   s.backward();
172*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(s.ndimension(), 0);
173*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3);
174*da0073e9SAndroid Build Coastguard Worker }
175*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Conv3dSameStrided) {
  // 3x4x5 kernel with 'same' padding. It is fine at stride 1 and must throw
  // for a uniform stride of 2 or for a stride of 2 along a single axis.
  Conv3dOptions options(3, 2, {3, 4, 5});
  options.stride(1);
  options.padding(torch::kSame);
  Conv3d model_valid(options);

  ASSERT_THROWS_WITH(
      [&] { Conv3d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] { Conv3d model_invalid(options.stride({1, 2, 1})); }(),
      "padding='same' is not supported for strided convolutions");
}
189*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ConvTranspose1d) {
  // Bias-free transposed 1-D convolution, kernel 3, stride 1.
  //
  // Transposed-conv weights are laid out as (in_channels, out_channels /
  // groups, kernel). The weight used below is (2, 3, 3), the input has 2
  // channels and the expected output has 3. The options therefore declare
  // 2 -> 3 channels, so the module's own weight already has the shape that
  // set_data installs. The previous 3 -> 2 declaration contradicted the data.
  ConvTranspose1d model(ConvTranspose1dOptions(2, 3, 3).stride(1).bias(false));
  model->weight.set_data(torch::arange(18.).view({2, 3, 3}));
  // Batch of 2, 2 channels, length 5 -> output length 5 + 3 - 1 = 7.
  auto x = torch::arange(20.).reshape({2, 2, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  // in_channels * out_channels * kernel.
  ASSERT_EQ(model->weight.grad().numel(), 2 * 3 * 3);
}
209*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ConvTranspose2dEven) {
  // Bias-free transposed 2-D convolution, 3x3 kernel, stride 1.
  //
  // The installed weight is (2, 3, 3, 3), i.e. (in, out / groups, kH, kW) for
  // a transposed conv. The input has 2 channels and the expected output has
  // 3. The options declare 2 -> 3 channels to agree with that data; the old
  // 3 -> 2 declaration contradicted it.
  ConvTranspose2d model(ConvTranspose2dOptions(2, 3, 3).stride(1).bias(false));
  model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3}));
  // One 2x5x5 input -> 3 output maps of (5 + 3 - 1) x (5 + 3 - 1) = 7x7.
  auto x = torch::arange(50.).view({1, 2, 5, 5});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  // in_channels * out_channels * kH * kW.
  ASSERT_EQ(model->weight.grad().numel(), 2 * 3 * 3 * 3);
}
244*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ConvTranspose2dUneven) {
  // Bias-free transposed 2-D convolution with a non-square 3x2 kernel.
  //
  // The installed weight is (2, 3, 3, 2) = (in, out / groups, kH, kW), the
  // input has 2 channels and the expected output has 3. The options therefore
  // declare 2 -> 3 channels instead of the contradictory 3 -> 2.
  ConvTranspose2d model(
      ConvTranspose2dOptions(2, 3, {3, 2}).stride({1, 1}).bias(false));
  model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2}));
  // One 2x5x4 input -> 3 output maps of (5 + 3 - 1) x (4 + 2 - 1) = 7x5.
  auto x = torch::arange(40.).view({1, 2, 5, 4});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  // in_channels * out_channels * kH * kW.
  ASSERT_EQ(model->weight.grad().numel(), 2 * 3 * 3 * 2);
}
280*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ConvTranspose3d) {
  // Bias-free transposed 3-D convolution, 2 -> 2 channels, 2x2x2 kernel,
  // stride 1, with weights 0..31. One 2x2x2x2 input -> 2 output volumes of
  // (2 + 2 - 1)^3 = 3x3x3.
  ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false));
  model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
  auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
  auto y = model(x);
  auto expected = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  torch::Tensor s = y.sum();
  s.backward();
  ASSERT_EQ(s.ndimension(), 0);
  // ASSERT_EQ (not ASSERT_TRUE(a == b)) reports both values on failure,
  // consistent with the other convolution tests.
  ASSERT_EQ(model->weight.grad().numel(), 2 * 2 * 2 * 2 * 2);
}
300*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool1d) {
  // Window 3, stride 2 over a length-5 all-ones signal gives 2 outputs, each
  // equal to 1.
  MaxPool1d pool(MaxPool1dOptions(3).stride(2));
  auto signal = torch::ones({1, 1, 5}, torch::requires_grad());
  auto pooled = pool(signal);

  torch::Tensor total = pooled.sum();
  total.backward();

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({1, 1, 2})));
  ASSERT_EQ(total.dim(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 2}));
}
313*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool1dReturnIndices) {
  // Same pooling as MaxPool1d, but the argmax positions are also returned.
  // With an all-ones input the expected indices are 0 and 2, the leftmost
  // position of each stride-2 window.
  MaxPool1d pool(MaxPool1dOptions(3).stride(2));
  auto signal = torch::ones({1, 1, 5}, torch::requires_grad());
  auto [pooled, argmax] = pool->forward_with_indices(signal);

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({1, 1, 2})));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 2}));

  const auto expected_argmax = torch::tensor({{{0, 2}}}, torch::kLong);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), std::vector<int64_t>({1, 1, 2}));
}
327*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool2dEven) {
  // 3x3 window, stride 2, on a 3-D all-ones input of shape (2, 5, 5). Each
  // 5x5 plane reduces to 2x2, and the leading dimension of 2 is kept.
  MaxPool2d pool(MaxPool2dOptions(3).stride(2));
  auto planes = torch::ones({2, 5, 5}, torch::requires_grad());
  auto pooled = pool(planes);

  torch::Tensor total = pooled.sum();
  total.backward();

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2})));
  ASSERT_EQ(total.dim(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2}));
}
340*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool2dUneven) {
  // Non-square 3x2 window, stride 2 in both dims, over (2, 5, 4) ones:
  // height (5 - 3) / 2 + 1 = 2, width (4 - 2) / 2 + 1 = 2.
  MaxPool2d pool(MaxPool2dOptions({3, 2}).stride({2, 2}));
  auto planes = torch::ones({2, 5, 4}, torch::requires_grad());
  auto pooled = pool(planes);

  torch::Tensor total = pooled.sum();
  total.backward();

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2})));
  ASSERT_EQ(total.dim(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2}));
}
353*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool2dReturnIndices) {
  // 3x3 window, stride 2, on (2, 5, 5) ones, also returning argmax positions.
  // The indices are flattened offsets inside each 5x5 plane (e.g. 10 is row 2,
  // column 0), and both planes yield the same pattern.
  MaxPool2d pool(MaxPool2dOptions(3).stride(2));
  auto planes = torch::ones({2, 5, 5}, torch::requires_grad());
  auto [pooled, argmax] = pool->forward_with_indices(planes);

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2})));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2}));

  const auto expected_argmax =
      torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), std::vector<int64_t>({2, 2, 2}));
}
367*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool3d) {
  // 3x3x3 window, stride 2, on (2, 5, 5, 5) ones. Each 5^3 volume becomes
  // 2^3, and the leading dimension of 2 is kept.
  MaxPool3d pool(MaxPool3dOptions(3).stride(2));
  auto volumes = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto pooled = pool(volumes);

  torch::Tensor total = pooled.sum();
  total.backward();

  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(total.dim(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
380*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool3dReturnIndices) {
  // 3x3x3 window, stride 2, on (2, 5, 5, 5) ones, also returning argmax
  // positions. The indices are flattened offsets inside each 5x5x5 volume
  // (e.g. 50 = depth 2 * 25), and both volumes yield the same pattern.
  MaxPool3d pool(MaxPool3dOptions(3).stride(2));
  auto volumes = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto [pooled, argmax] = pool->forward_with_indices(volumes);

  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  const auto expected_argmax = torch::tensor(
      {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
       {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}},
      torch::kLong);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
398*da0073e9SAndroid Build Coastguard Worker 
// AvgPool1d (kernel 3, stride 2) over a length-5 all-ones signal: two
// windows fit, and averaging ones yields ones.
TEST_F(ModulesTest, AvgPool1d) {
  AvgPool1d pool(AvgPool1dOptions(3).stride(2));
  auto input = torch::ones({1, 1, 5}, torch::requires_grad());
  auto averaged = pool(input);
  torch::Tensor total = averaged.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{1, 1, 2};
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, torch::ones({1, 1, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.sizes(), expected_shape);
}
411*da0073e9SAndroid Build Coastguard Worker 
// AvgPool2d with a square kernel (3, stride 2) over an all-ones 2x5x5 input:
// the output is a 2x2x2 block of ones.
TEST_F(ModulesTest, AvgPool2dEven) {
  AvgPool2d pool(AvgPool2dOptions(3).stride(2));
  auto input = torch::ones({2, 5, 5}, torch::requires_grad());
  auto averaged = pool(input);
  torch::Tensor total = averaged.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{2, 2, 2};
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, torch::ones({2, 2, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.sizes(), expected_shape);
}
424*da0073e9SAndroid Build Coastguard Worker 
// AvgPool2d with a rectangular kernel {3, 2} and stride {2, 2} over an
// all-ones 2x5x4 input: both spatial extents collapse to 2.
TEST_F(ModulesTest, AvgPool2dUneven) {
  AvgPool2d pool(AvgPool2dOptions({3, 2}).stride({2, 2}));
  auto input = torch::ones({2, 5, 4}, torch::requires_grad());
  auto averaged = pool(input);
  torch::Tensor total = averaged.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{2, 2, 2};
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, torch::ones({2, 2, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.sizes(), expected_shape);
}
437*da0073e9SAndroid Build Coastguard Worker 
// AvgPool3d (kernel 3, stride 2) over an all-ones 2x5x5x5 input: the output
// is a 2x2x2x2 block of ones.
TEST_F(ModulesTest, AvgPool3d) {
  AvgPool3d pool(AvgPool3dOptions(3).stride(2));
  auto input = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto averaged = pool(input);
  torch::Tensor total = averaged.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{2, 2, 2, 2};
  ASSERT_EQ(averaged.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(averaged, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.sizes(), expected_shape);
}
450*da0073e9SAndroid Build Coastguard Worker 
// FractionalMaxPool2d (kernel 3, output_size 2) over an all-ones 2x5x5
// input. output_size pins the spatial result to 2x2, and because the input
// is all ones every pooled value is 1.
TEST_F(ModulesTest, FractionalMaxPool2d) {
  FractionalMaxPool2d pool(FractionalMaxPool2dOptions(3).output_size(2));
  auto input = torch::ones({2, 5, 5}, torch::requires_grad());
  auto pooled = pool(input);
  torch::Tensor total = pooled.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{2, 2, 2};
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
463*da0073e9SAndroid Build Coastguard Worker 
// FractionalMaxPool2d (kernel 3, output_size 2) with indices on an all-ones
// 2x5x5 input. The expected flat indices (0, 2, 10, 12 within each 5x5
// plane) correspond to pooling windows starting at rows/cols 0 and 2.
TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) {
  FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
  // State the integer dtype explicitly, as MaxPool3dReturnIndices does,
  // rather than relying on torch::tensor's inference for integer literals.
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
}
476*da0073e9SAndroid Build Coastguard Worker 
// FractionalMaxPool3d (kernel 3, output_size 2) over an all-ones 2x5x5x5
// input: the result is a 2x2x2x2 block of ones.
TEST_F(ModulesTest, FractionalMaxPool3d) {
  FractionalMaxPool3d pool(FractionalMaxPool3dOptions(3).output_size(2));
  auto input = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto pooled = pool(input);
  torch::Tensor total = pooled.sum();
  total.backward();

  const std::vector<int64_t> expected_shape{2, 2, 2, 2};
  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(pooled.sizes(), expected_shape);
}
489*da0073e9SAndroid Build Coastguard Worker 
// FractionalMaxPool3d (kernel 3, output_size 2) with indices on an all-ones
// 2x5x5x5 input. The expected indices are flat offsets d*25 + h*5 + w into
// each 5x5x5 volume, with d, h, w drawn from {0, 2}.
TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) {
  FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
  auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
  auto [y, indices] = model->forward_with_indices(x);

  ASSERT_EQ(y.dim(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  // State the integer dtype explicitly, as MaxPool3dReturnIndices does,
  // rather than relying on torch::tensor's inference for integer literals.
  ASSERT_TRUE(torch::allclose(
      indices,
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}},
          torch::kLong)));
  ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
506*da0073e9SAndroid Build Coastguard Worker 
// LPPool1d with p = 2, kernel 3, stride 2 over a length-5 ones signal. Each
// window holds kernel_size ones, so every output equals
// (kernel_size * 1^p)^(1/p).
TEST_F(ModulesTest, LPPool1d) {
  const int norm_type = 2;
  const int stride = 2;
  const int kernel_size = 3;

  LPPool1d pool(LPPool1dOptions(norm_type, kernel_size).stride(stride));
  auto input = torch::ones({1, 1, 5});
  auto pooled = pool(input);

  auto powered =
      torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type);
  auto expected = (powered * kernel_size).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 2}));
}
524*da0073e9SAndroid Build Coastguard Worker 
// LPPool2d with p = 2, a 2x3 kernel and stride 2 over a 1x1x2x5 ones input.
// Each window holds 2*3 ones, so every output equals (6 * 1^p)^(1/p).
TEST_F(ModulesTest, LPPool2d) {
  const int norm_type = 2;
  const int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  LPPool2d pool(LPPool2dOptions(norm_type, kernel_size).stride(stride));
  auto input = torch::ones({1, 1, 2, 5});
  auto pooled = pool(input);

  const auto window_elems = kernel_size[0] * kernel_size[1];
  auto powered =
      torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type);
  auto expected = (powered * window_elems).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 1, 2}));
}
542*da0073e9SAndroid Build Coastguard Worker 
// LPPool3d with p = 2, a 1x2x3 kernel and stride 2 over a 1x1x1x2x5 ones
// input. Each window holds 1*2*3 ones, so every output equals
// (6 * 1^p)^(1/p).
TEST_F(ModulesTest, LPPool3d) {
  const int norm_type = 2;
  const int stride = 2;
  std::vector<int64_t> kernel_size({1, 2, 3});

  LPPool3d pool(LPPool3dOptions(norm_type, kernel_size).stride(stride));
  auto input = torch::ones({1, 1, 1, 2, 5});
  auto pooled = pool(input);

  const auto window_elems = kernel_size[0] * kernel_size[1] * kernel_size[2];
  auto powered =
      torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type);
  auto expected = (powered * window_elems).pow(1. / norm_type);

  ASSERT_EQ(pooled.ndimension(), 5);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 1, 1, 2}));
}
560*da0073e9SAndroid Build Coastguard Worker 
// Identity must return its input unchanged and pass gradients straight
// through: the gradient of sum(output) w.r.t. the input is all ones.
TEST_F(ModulesTest, Identity) {
  Identity passthrough;
  auto input = torch::tensor(
      {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto result = passthrough->forward(input);
  auto total = result.sum();
  total.backward();

  const auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
  ASSERT_TRUE(torch::equal(result, expected));
  ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
}
573*da0073e9SAndroid Build Coastguard Worker 
// Flatten with default options, then with start_dim/end_dim. In both cases
// the values must be preserved and gradients must flow through unchanged.
TEST_F(ModulesTest, Flatten) {
  // Default options: a 2x3 input comes out unchanged.
  Flatten flatten;
  auto matrix = torch::tensor(
      {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto flat = flatten->forward(matrix);
  flat.sum().backward();

  const auto flat_expected =
      torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat);
  ASSERT_TRUE(torch::equal(flat, flat_expected));
  ASSERT_TRUE(torch::equal(matrix.grad(), torch::ones_like(matrix)));

  // start_dim=2, end_dim=3 merges the last two axes:
  // sizes (2, 2, 2, 2) -> (2, 2, 4).
  Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3));
  auto hypercube = torch::tensor(
      {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
       {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto merged = flatten_optional_dims->forward(hypercube);
  merged.sum().backward();

  const auto merged_expected = torch::tensor(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}},
      torch::kFloat);
  ASSERT_TRUE(torch::equal(merged, merged_expected));
  ASSERT_TRUE(torch::equal(hypercube.grad(), torch::ones_like(hypercube)));
}
604*da0073e9SAndroid Build Coastguard Worker 
// Unflatten splits one dimension into several: by index for an unnamed
// tensor, and by dimension name (producing named sub-dimensions) for a
// named tensor.
TEST_F(ModulesTest, Unflatten) {
  // Non-named tensor
  Unflatten unflatten(UnflattenOptions(0, {2, 2}));
  auto output = unflatten->forward(torch::tensor({1, 2, 3, 4}));
  auto expected = torch::tensor({{1, 2}, {3, 4}});
  ASSERT_TRUE(torch::equal(output, expected));

  // Named tensor
  // Builds a list of Dimnames from plain strings. The names are taken by
  // const reference and the result is reserved up front, so neither the
  // vector nor its strings are copied. This removes the copies that were
  // previously silenced with NOLINT.
  auto make_dimnames = [](const std::vector<std::string>& names) {
    std::vector<torch::Dimname> dimnames;
    dimnames.reserve(names.size());
    for (const auto& name : names) {
      dimnames.push_back(
          torch::Dimname::fromSymbol(torch::Symbol::dimname(name)));
    }
    return dimnames;
  };

  unflatten = Unflatten(UnflattenOptions(
      "B",
      {std::pair<std::string, int64_t>{"B1", 2},
       std::pair<std::string, int64_t>{"B2", 2}}));
  output = unflatten->forward(
      torch::tensor({{1, 2, 3, 4}}).refine_names(make_dimnames({"A", "B"})));
  expected = torch::tensor({{{1, 2}, {3, 4}}})
                 .refine_names(make_dimnames({"A", "B1", "B2"}));
  ASSERT_TRUE(torch::equal(output, expected));
}
634*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool1d to 3 outputs over [1, 2, 3, 4, 5]; the expected maxima
// are 2, 4 and 5.
TEST_F(ModulesTest, AdaptiveMaxPool1d) {
  AdaptiveMaxPool1d pool(3);
  auto signal = torch::tensor(
      {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto pooled = pool(signal);
  torch::Tensor total = pooled.sum();
  total.backward();

  const auto expected = torch::tensor({{{2, 4, 5}}}, torch::kFloat);
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 1, 3}));
}
648*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool1d::forward_with_indices over [1, 2, 3, 4, 5]: maxima 2, 4
// and 5, located at positions 1, 3 and 4.
TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) {
  AdaptiveMaxPool1d pool(3);
  auto signal = torch::tensor(
      {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto [pooled, argmax] = pool->forward_with_indices(signal);

  const std::vector<int64_t> expected_shape{1, 1, 3};
  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(
      torch::allclose(pooled, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(pooled.sizes(), expected_shape);

  const auto expected_argmax = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), expected_shape);
}
662*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool2d to a 3x3 grid over arange(50) laid out as 2x5x5.
TEST_F(ModulesTest, AdaptiveMaxPool2dEven) {
  AdaptiveMaxPool2d pool(3);
  auto grid = torch::arange(0., 50);
  grid.resize_({2, 5, 5}).set_requires_grad(true);
  auto pooled = pool(grid);
  torch::Tensor total = pooled.sum();
  total.backward();

  const auto expected = torch::tensor(
      {
          {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
          {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
      },
      torch::kFloat);
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 3, 3}));
}
683*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool2d to a non-square 3x2 grid over arange(40) laid out as
// 2x5x4.
TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) {
  AdaptiveMaxPool2d pool(AdaptiveMaxPool2dOptions({3, 2}));
  auto grid = torch::arange(0., 40);
  grid.resize_({2, 5, 4}).set_requires_grad(true);
  auto pooled = pool(grid);
  torch::Tensor total = pooled.sum();
  total.backward();

  const auto expected = torch::tensor(
      {
          {{5, 7}, {13, 15}, {17, 19}},
          {{25, 27}, {33, 35}, {37, 39}},
      },
      torch::kFloat);
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected));
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({2, 3, 2}));
}
704*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool2d::forward_with_indices to a 3x3 grid over arange(50)
// laid out as 2x5x5. Indices are flat offsets within each 5x5 plane, so both
// channels report the same index pattern.
TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) {
  AdaptiveMaxPool2d pool(3);
  auto grid = torch::arange(0., 50);
  grid.resize_({2, 5, 5}).set_requires_grad(true);
  auto [pooled, argmax] = pool->forward_with_indices(grid);
  torch::Tensor total = pooled.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);

  const std::vector<int64_t> expected_shape{2, 3, 3};

  const auto expected_values = torch::tensor(
      {
          {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
          {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
      },
      torch::kFloat);
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected_values));
  ASSERT_EQ(pooled.sizes(), expected_shape);

  const auto expected_argmax = torch::tensor(
      {
          {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
          {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
      },
      torch::kLong);
  ASSERT_EQ(argmax.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), expected_shape);
}
737*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool2d::forward_with_indices to a 3x2 grid over arange(40)
// laid out as 2x5x4. Indices are flat offsets within each 5x4 plane, so both
// channels report the same index pattern.
TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) {
  AdaptiveMaxPool2d pool(AdaptiveMaxPool2dOptions({3, 2}));
  auto grid = torch::arange(0., 40);
  grid.resize_({2, 5, 4}).set_requires_grad(true);
  auto [pooled, argmax] = pool->forward_with_indices(grid);
  torch::Tensor total = pooled.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);

  const std::vector<int64_t> expected_shape{2, 3, 2};

  const auto expected_values = torch::tensor(
      {
          {{5, 7}, {13, 15}, {17, 19}},
          {{25, 27}, {33, 35}, {37, 39}},
      },
      torch::kFloat);
  ASSERT_EQ(pooled.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(pooled, expected_values));
  ASSERT_EQ(pooled.sizes(), expected_shape);

  const auto expected_argmax = torch::tensor(
      {
          {{5, 7}, {13, 15}, {17, 19}},
          {{5, 7}, {13, 15}, {17, 19}},
      },
      torch::kLong);
  ASSERT_EQ(argmax.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(argmax, expected_argmax));
  ASSERT_EQ(argmax.sizes(), expected_shape);
}
770*da0073e9SAndroid Build Coastguard Worker 
// AdaptiveMaxPool3d to a 3x3x3 grid over arange(64) laid out as 1x4x4x4.
TEST_F(ModulesTest, AdaptiveMaxPool3d) {
  AdaptiveMaxPool3d model(3);
  auto x = torch::arange(0., 64);
  x.resize_({1, 4, 4, 4}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(s.ndimension(), 0);

  ASSERT_EQ(y.ndimension(), 4);
  // The expected tensor includes the leading size-1 dimension so it has
  // exactly y's shape (1, 3, 3, 3). The value comparison therefore does not
  // depend on allclose broadcasting a (3, 3, 3) tensor.
  ASSERT_TRUE(torch::allclose(
      y,
      torch::tensor(
          {{
              {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
              {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
              {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
          }},
          torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}
793*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) {
  // Adaptive 3D max pooling via forward_with_indices, so the argmax
  // positions are returned together with the pooled values.
  AdaptiveMaxPool3d pool(3);
  auto input = torch::arange(0., 64);
  input.resize_({1, 4, 4, 4}).set_requires_grad(true);

  auto [pooled, argmax] = pool->forward_with_indices(input);
  torch::Tensor total = pooled.sum();
  total.backward();
  ASSERT_EQ(total.ndimension(), 0);

  // The input is arange(64) in order, so each maximum's flat index equals
  // its value. One table therefore describes both outputs.
  const auto expected_values = torch::tensor(
      {
          {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
          {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
          {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
      },
      torch::kFloat);
  const auto expected_indices = expected_values.to(torch::kLong);

  ASSERT_EQ(pooled.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(pooled, expected_values));
  ASSERT_EQ(pooled.sizes(), std::vector<int64_t>({1, 3, 3, 3}));

  ASSERT_EQ(argmax.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(argmax, expected_indices));
  ASSERT_EQ(argmax.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}
828*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, AdaptiveAvgPool1d) {
  // Five inputs averaged into three overlapping windows:
  // {1,2} -> 1.5, {2,3,4} -> 3.0, {4,5} -> 4.5.
  AdaptiveAvgPool1d pool(3);
  const auto options = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::tensor({{{1, 2, 3, 4, 5}}}, options);

  const auto averaged = pool(input);
  torch::Tensor total = averaged.sum();
  total.backward();

  const auto expected = torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat);

  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, expected));
  ASSERT_EQ(averaged.sizes(), std::vector<int64_t>({1, 1, 3}));
}
844*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
  // Two 5x5 planes (values 0..49) averaged down to 3x3 per plane.
  AdaptiveAvgPool2d pool(3);
  auto planes = torch::arange(0., 50);
  planes.resize_({2, 5, 5}).set_requires_grad(true);

  const auto averaged = pool(planes);
  torch::Tensor total = averaged.sum();
  total.backward();

  const auto expected = torch::tensor(
      {
          {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}},
          {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}},
      },
      torch::kFloat);

  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, expected));
  ASSERT_EQ(averaged.sizes(), std::vector<int64_t>({2, 3, 3}));
}
866*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
  // Non-square target: each 5x4 plane is averaged down to 3 rows x 2 columns.
  AdaptiveAvgPool2d pool(AdaptiveAvgPool2dOptions({3, 2}));
  auto planes = torch::arange(0., 40);
  planes.resize_({2, 5, 4}).set_requires_grad(true);

  const auto averaged = pool(planes);
  torch::Tensor total = averaged.sum();
  total.backward();

  const auto expected = torch::tensor(
      {
          {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}},
          {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}},
      },
      torch::kFloat);

  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(averaged, expected));
  ASSERT_EQ(averaged.sizes(), std::vector<int64_t>({2, 3, 2}));
}
888*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, AdaptiveAvgPool3d) {
  // 4x4x4 arange volume averaged down to 3x3x3. Each expected entry is the
  // mean of a 2x2x2 window; e.g. the first is mean{0,1,4,5,16,17,20,21} = 10.5.
  AdaptiveAvgPool3d pool(3);
  auto volume = torch::arange(0., 64);
  volume.resize_({1, 4, 4, 4}).set_requires_grad(true);

  const auto averaged = pool(volume);
  torch::Tensor total = averaged.sum();
  total.backward();

  const auto expected = torch::tensor(
      {
          {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}},
          {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}},
          {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}},
      },
      torch::kFloat);

  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(averaged.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(averaged, expected));
  ASSERT_EQ(averaged.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
}
911*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxUnpool1d) {
  // Kernel 3 with default options: three values scatter into a length-9
  // signal at the given positions, and every other slot is zero.
  {
    const auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    const auto values = torch::tensor(
        {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
    auto unpool = MaxUnpool1d{3};
    const auto restored = unpool->forward(values, positions);

    const auto expected =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_EQ(restored.dim(), 3);
    ASSERT_TRUE(torch::allclose(restored, expected));
    ASSERT_EQ(restored.sizes(), std::vector<int64_t>({1, 1, 9}));
  }

  // Explicit stride/padding together with an explicit output_size of 5.
  {
    const auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    const auto values = torch::tensor(
        {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
    auto unpool = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)};
    const auto restored =
        unpool->forward(values, positions, std::vector<int64_t>({1, 1, 5}));

    const auto expected = torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat);
    ASSERT_EQ(restored.dim(), 3);
    ASSERT_TRUE(torch::allclose(restored, expected));
    ASSERT_EQ(restored.sizes(), std::vector<int64_t>({1, 1, 5}));
  }
}
935*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) {
  MaxPool1d pool{MaxPool1dOptions(2).stride(2)};
  MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)};

  // Even-length input: each pair keeps its right-hand (larger) element in place.
  const auto even_input =
      torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat);
  auto [even_pooled, even_idx] = pool->forward_with_indices(even_input);
  const auto even_expected =
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(unpool(even_pooled, even_idx), even_expected));

  // Odd-length input (shows the use of output_size): passing the original
  // sizes restores length 9, while the default output length is 8.
  const auto odd_input =
      torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat);
  auto [odd_pooled, odd_idx] = pool->forward_with_indices(odd_input);
  const auto sized_expected =
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat);
  const auto default_expected =
      torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat);
  ASSERT_TRUE(torch::allclose(
      unpool(odd_pooled, odd_idx, odd_input.sizes().vec()), sized_expected));
  ASSERT_TRUE(torch::allclose(unpool(odd_pooled, odd_idx), default_expected));
}
955*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxUnpool2d) {
  // Two batch entries share the same flat positions inside a 5x5 plane but
  // carry different values. With kernel 3, stride 2, padding 1 the default
  // output plane is 5x5.
  const auto positions = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
      torch::kLong);
  const auto values = torch::tensor(
      {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
       {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto unpool = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)};
  const auto restored = unpool->forward(values, positions);

  const auto expected = torch::tensor(
      {{{{0, 0, 0, 0, 0},
         {0, 6, 0, 8, 9},
         {0, 0, 0, 0, 0},
         {0, 16, 0, 18, 19},
         {0, 21, 0, 23, 24}}},
       {{{0, 0, 0, 0, 0},
         {0, 31, 0, 33, 34},
         {0, 0, 0, 0, 0},
         {0, 41, 0, 43, 44},
         {0, 46, 0, 48, 49}}}},
      torch::kFloat);

  ASSERT_EQ(restored.dim(), 4);
  ASSERT_TRUE(torch::allclose(restored, expected));
  ASSERT_EQ(restored.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
}
985*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) {
  MaxPool2d pool{MaxPool2dOptions(2).stride(2)};
  MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)};
  const auto grid = torch::tensor(
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}},
      torch::kFloat);
  auto [maxima, flat_idx] = pool->forward_with_indices(grid);

  // Default output size: the maxima return to their original 4x4 positions.
  const auto same_shape = torch::tensor(
      {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(unpool(maxima, flat_idx), same_shape));

  // Forcing a 5x5 output reinterprets the flat indices (5, 7, 13, 15) in the
  // wider row stride, so the maxima move to new coordinates.
  const auto wider_shape = torch::tensor(
      {{{{0, 0, 0, 0, 0},
         {6, 0, 8, 0, 0},
         {0, 0, 0, 14, 0},
         {16, 0, 0, 0, 0},
         {0, 0, 0, 0, 0}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(
      unpool(maxima, flat_idx, std::vector<int64_t>{1, 1, 5, 5}),
      wider_shape));
}
1009*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxUnpool3d) {
  // A single value at flat index 26 unpools into the last cell of a 3x3x3 cube.
  const auto position = torch::tensor({{{{{26}}}}}, torch::kLong);
  const auto value = torch::tensor(
      {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto unpool = MaxUnpool3d{3};
  const auto restored = unpool->forward(value, position);

  const auto expected = torch::tensor(
      {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
         {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
         {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
      torch::kFloat);

  ASSERT_EQ(restored.dim(), 5);
  ASSERT_TRUE(torch::allclose(restored, expected));
  ASSERT_EQ(restored.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
}
1027*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxUnpool3dOutputSize) {
  // A 2x2x2 block of values scattered into an explicitly requested 4x4x4
  // output. Each value equals its flat target index.
  const auto positions = torch::tensor(
      {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong);
  const auto values = torch::tensor(
      {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto unpool = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)};
  const std::vector<int64_t> requested_size({1, 1, 4, 4, 4});
  const auto restored = unpool->forward(values, positions, requested_size);

  const auto expected = torch::tensor(
      {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
         {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}},
         {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
         {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}},
      torch::kFloat);

  ASSERT_EQ(restored.dim(), 5);
  ASSERT_TRUE(torch::allclose(restored, expected));
  ASSERT_EQ(restored.sizes(), std::vector<int64_t>({1, 1, 4, 4, 4}));
}
1048*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) {
  // Round trip MaxPool3d -> MaxUnpool3d with kernel 3 / stride 2.
  // The spatial extents (51, 33, 15) all satisfy (size - 3) % 2 == 0, so
  // unpooling without an explicit output_size recovers exactly the input's
  // spatial shape.
  //
  // Batch and channel counts do not affect this shape arithmetic, so they are
  // kept small. A 20 x 16 batch/channel layout would allocate ~8M floats
  // (~32 MB) per tensor just to check a shape.
  MaxPool3d pool{MaxPool3dOptions(3).stride(2)};
  MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)};
  auto input = torch::randn({2, 3, 51, 33, 15});
  auto [output, indices] = pool->forward_with_indices(input);
  auto unpooled_output = unpool(output, indices);
  ASSERT_EQ(
      unpooled_output.sizes(), std::vector<int64_t>({2, 3, 51, 33, 15}));

  // Beyond the shape: each pooled maximum must sit at the flat per-plane
  // position recorded in `indices`. Overlapping windows that share an argmax
  // write the same value, so an exact comparison is valid.
  const auto gathered =
      unpooled_output.flatten(2).gather(2, indices.flatten(2));
  ASSERT_TRUE(torch::equal(gathered, output.flatten(2)));
}
1058*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Linear) {
  // With bias: output must match addmm(bias, x, W^T).
  {
    Linear fc(5, 2);
    auto input = torch::randn({10, 5}, torch::requires_grad());
    auto out = fc(input);
    torch::Tensor total = out.sum();
    total.backward();

    ASSERT_EQ(out.ndimension(), 2);
    ASSERT_EQ(total.ndimension(), 0);
    ASSERT_EQ(out.size(0), 10);
    ASSERT_EQ(out.size(1), 2);
    ASSERT_EQ(fc->weight.grad().numel(), 2 * 5);

    const auto reference = torch::addmm(fc->bias, input, fc->weight.t());
    ASSERT_TRUE(torch::allclose(out, reference));
  }

  // Without bias: output must match a plain mm(x, W^T).
  {
    Linear fc(LinearOptions(5, 2).bias(false));
    auto input = torch::randn({10, 5}, torch::requires_grad());
    auto out = fc(input);
    torch::Tensor total = out.sum();
    total.backward();

    ASSERT_EQ(out.ndimension(), 2);
    ASSERT_EQ(total.ndimension(), 0);
    ASSERT_EQ(out.size(0), 10);
    ASSERT_EQ(out.size(1), 2);
    ASSERT_EQ(fc->weight.grad().numel(), 2 * 5);

    const auto reference = torch::mm(input, fc->weight.t());
    ASSERT_TRUE(torch::allclose(out, reference));
  }
}
1095*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, LocalResponseNorm) {
  // LocalResponseNorm with size 2 over a 2x3x3x2 ramp (100..135). The expected
  // values are hard-coded reference outputs (presumably from the Python
  // implementation) and are compared with rtol 1e-4 / atol 1e-7.
  LocalResponseNorm lrn(LocalResponseNormOptions(2));
  const auto ramp =
      torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2});
  auto normalized = lrn(ramp);

  const auto reference = torch::tensor(
      {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}},
        {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}},
        {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}},
       {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}},
        {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}},
        {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}},
      torch::kFloat);

  torch::Tensor total = normalized.sum();
  total.backward();

  ASSERT_EQ(normalized.ndimension(), 4);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(normalized.sizes(), ramp.sizes());
  ASSERT_TRUE(torch::allclose(normalized, reference, 1e-4, 1e-7));
}
1124*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, LayerNorm) {
  // The module must agree with the functional torch::layer_norm when it is
  // given the same affine parameters and eps.
  constexpr double kEps = 2e-5;
  LayerNorm norm(LayerNormOptions({2, 2}).eps(kEps));
  auto input = torch::randn({2, 2}, torch::requires_grad());
  auto out = norm(input);
  const auto reference =
      torch::layer_norm(input, {2, 2}, norm->weight, norm->bias, kEps);

  torch::Tensor total = out.sum();
  total.backward();

  ASSERT_EQ(out.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({2, 2}));
  ASSERT_EQ(norm->weight.grad().numel(), 2 * 2);
  ASSERT_TRUE(torch::allclose(out, reference));
}
1142*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, GroupNorm) {
  // Two groups over two channels. The module must agree with the functional
  // torch::group_norm when it is given the same affine parameters and eps.
  constexpr double kEps = 2e-5;
  GroupNorm norm(GroupNormOptions(2, 2).eps(kEps));
  auto input = torch::randn({2, 2}, torch::requires_grad());
  auto out = norm(input);
  const auto reference =
      torch::group_norm(input, 2, norm->weight, norm->bias, kEps);

  torch::Tensor total = out.sum();
  total.backward();

  ASSERT_EQ(out.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({2, 2}));
  ASSERT_EQ(norm->weight.grad().numel(), 2);
  ASSERT_TRUE(torch::allclose(out, reference));
}
1160*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Bilinear) {
  // Bilinear(5, 3, 2): combines a 5-feature and a 3-feature input into 2
  // outputs. The weight therefore holds 2 * 5 * 3 elements.
  Bilinear bilinear(5, 3, 2);
  auto lhs = torch::randn({10, 5}, torch::requires_grad());
  auto rhs = torch::randn({10, 3}, torch::requires_grad());
  auto out = bilinear(lhs, rhs);

  torch::Tensor total = out.sum();
  total.backward();

  ASSERT_EQ(out.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(out.sizes(), std::vector<int64_t>({10, 2}));
  ASSERT_EQ(bilinear->weight.grad().numel(), 2 * 5 * 3);
}
1176*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Fold) {
  // Folding two 2x2 blocks of ones into a 3x2 output: the middle row is
  // covered by both blocks, so it sums to 2 while the edge rows stay at 1.
  {
    Fold fold(FoldOptions({3, 2}, {2, 2}));
    auto blocks = torch::ones({1, 3 * 2 * 2, 2}, torch::requires_grad());
    auto folded = fold(blocks);

    const auto overlap_counts = torch::tensor(
        {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
          {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
          {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
        torch::kFloat);

    auto total = folded.sum();
    total.backward();

    ASSERT_EQ(total.ndimension(), 0);
    ASSERT_EQ(folded.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
    ASSERT_TRUE(folded.allclose(overlap_counts));
  }

  // A 4D input is rejected: only 2D (unbatched) or 3D (batched) is accepted.
  {
    Fold fold(FoldOptions({8, 8}, {3, 3}));
    const auto four_d = torch::randn({1, 3, 16, 16});
    ASSERT_THROWS_WITH(
        fold(four_d),
        "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported (got 4D)");
  }
}
1202*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,Unfold)1203*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Unfold) {
1204*da0073e9SAndroid Build Coastguard Worker   {
1205*da0073e9SAndroid Build Coastguard Worker     Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2));
1206*da0073e9SAndroid Build Coastguard Worker     auto input =
1207*da0073e9SAndroid Build Coastguard Worker         torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3});
1208*da0073e9SAndroid Build Coastguard Worker     auto output = model(input);
1209*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
1210*da0073e9SAndroid Build Coastguard Worker         {{{0.0, 0.0, 0.0, 6.0},
1211*da0073e9SAndroid Build Coastguard Worker           {0.0, 0.0, 5.0, 7.0},
1212*da0073e9SAndroid Build Coastguard Worker           {0.0, 3.0, 0.0, 0.0},
1213*da0073e9SAndroid Build Coastguard Worker           {2.0, 4.0, 0.0, 0.0},
1214*da0073e9SAndroid Build Coastguard Worker           {0.0, 0.0, 0.0, 12.0},
1215*da0073e9SAndroid Build Coastguard Worker           {0.0, 0.0, 11.0, 13.0},
1216*da0073e9SAndroid Build Coastguard Worker           {0.0, 9.0, 0.0, 0.0},
1217*da0073e9SAndroid Build Coastguard Worker           {8.0, 10.0, 0.0, 0.0}}},
1218*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
1219*da0073e9SAndroid Build Coastguard Worker     auto s = output.sum();
1220*da0073e9SAndroid Build Coastguard Worker     s.backward();
1221*da0073e9SAndroid Build Coastguard Worker 
1222*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(s.ndimension(), 0);
1223*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
1224*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
1225*da0073e9SAndroid Build Coastguard Worker   }
1226*da0073e9SAndroid Build Coastguard Worker   {
1227*da0073e9SAndroid Build Coastguard Worker     // input wrong dimension
1228*da0073e9SAndroid Build Coastguard Worker     Unfold model(UnfoldOptions({2, 4}));
1229*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
1230*da0073e9SAndroid Build Coastguard Worker         model(torch::randn({1, 5, 2})),
1231*da0073e9SAndroid Build Coastguard Worker         "Input Error: Only 4D input Tensors are supported (got 3D)");
1232*da0073e9SAndroid Build Coastguard Worker   }
1233*da0073e9SAndroid Build Coastguard Worker   {
1234*da0073e9SAndroid Build Coastguard Worker     // calculated output shape is too small
1235*da0073e9SAndroid Build Coastguard Worker     Unfold model(UnfoldOptions({2, 3}));
1236*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
1237*da0073e9SAndroid Build Coastguard Worker         model(torch::randn({1, 2, 2, 2})),
1238*da0073e9SAndroid Build Coastguard Worker         "Given input with spatial size (2, 2), kernel_size=(2, 3), "
1239*da0073e9SAndroid Build Coastguard Worker         "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
1240*da0073e9SAndroid Build Coastguard Worker         "sliding blocks as (1, 0), but its components must be at least one.");
1241*da0073e9SAndroid Build Coastguard Worker   }
1242*da0073e9SAndroid Build Coastguard Worker }
1243*da0073e9SAndroid Build Coastguard Worker 
// Layers registered on a SimpleContainer stay usable as ordinary modules.
// Chain three Linear layers, each followed by clamp_min(0), and check the
// shape of the result and its gradient flow.
TEST_F(ModulesTest, SimpleContainer) {
  auto container = std::make_shared<SimpleContainer>();
  auto first = container->add(Linear(10, 3), "l1");
  auto second = container->add(Linear(3, 5), "l2");
  auto third = container->add(Linear(5, 100), "l3");

  const auto relu = [](const torch::Tensor& t) { return t.clamp_min(0); };

  auto out = torch::randn({1000, 10}, torch::requires_grad());
  out = relu(first(out));
  out = relu(second(out));
  out = relu(third(out));

  out.backward(torch::ones_like(out));
  ASSERT_EQ(out.ndimension(), 2);
  ASSERT_EQ(out.size(0), 1000);
  ASSERT_EQ(out.size(1), 100);
  // clamp_min(0) maps every negative pre-activation to exactly 0.
  ASSERT_EQ(out.min().item<float>(), 0);
}
1261*da0073e9SAndroid Build Coastguard Worker 
// An Embedding exposes a (dict_size x 2) weight. Gradients reach that weight
// but not the integer indices.
TEST_F(ModulesTest, EmbeddingBasic) {
  const int64_t dict_size = 10;
  Embedding embedding(dict_size, 2);
  ASSERT_TRUE(embedding->named_parameters().contains("weight"));
  ASSERT_EQ(embedding->weight.ndimension(), 2);
  ASSERT_EQ(embedding->weight.size(0), dict_size);
  ASSERT_EQ(embedding->weight.size(1), 2);

  // Indices cannot receive gradients; only the embedding parameters can.
  // Every index here selects the last row of the table.
  auto indices = torch::full({10}, dict_size - 1, torch::kInt64);
  auto vectors = embedding(indices);
  torch::Tensor total = vectors.sum();
  total.backward();

  ASSERT_EQ(vectors.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(vectors.size(0), 10);
  ASSERT_EQ(vectors.size(1), 2);

  // The gradient spans the whole table, not just the looked-up row.
  ASSERT_EQ(embedding->weight.grad().numel(), 2 * dict_size);
}
1284*da0073e9SAndroid Build Coastguard Worker 
// Embedding accepts a 2-D index tensor and appends the embedding dimension
// to its shape.
TEST_F(ModulesTest, EmbeddingList) {
  Embedding embedding(6, 4);
  auto indices = torch::full({2, 3}, 5, torch::kInt64);
  auto vectors = embedding(indices);
  torch::Tensor total = vectors.sum();

  total.backward();
  ASSERT_EQ(vectors.sizes(), std::vector<int64_t>({2, 3, 4}));
}
1297*da0073e9SAndroid Build Coastguard Worker 
// from_pretrained wraps an existing weight matrix, so looking up index 1
// yields that matrix's second row.
TEST_F(ModulesTest, EmbeddingFromPretrained) {
  const auto table = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
  Embedding embedding = torch::nn::Embedding::from_pretrained(table);
  const auto row = embedding(torch::tensor({1}, torch::kLong));
  ASSERT_TRUE(torch::allclose(row, torch::tensor({4.0000, 5.1000, 6.3000})));
}
1305*da0073e9SAndroid Build Coastguard Worker 
// EmbeddingBag::from_pretrained reduces each bag of indices. For the single
// bag {1, 0}, the expected result is the element-wise mean of both rows.
TEST_F(ModulesTest, EmbeddingBagFromPretrained) {
  const auto table = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
  EmbeddingBag bag = torch::nn::EmbeddingBag::from_pretrained(table);
  // One bag: a [1, 2] int64 tensor holding indices 1 and 0.
  const auto bags = torch::tensor({{1, 0}}, torch::kLong);
  ASSERT_TRUE(
      torch::allclose(bag(bags), torch::tensor({2.5000, 3.7000, 4.6500})));
}
1314*da0073e9SAndroid Build Coastguard Worker 
// In training mode AlphaDropout perturbs the input but keeps its sum within a
// loose band. In eval mode it returns the input unchanged.
TEST_F(ModulesTest, AlphaDropout) {
  AlphaDropout alpha_dropout(0.5);
  torch::Tensor ones = torch::ones(100, torch::requires_grad());
  torch::Tensor dropped = alpha_dropout(ones);

  dropped.backward(torch::ones_like(dropped));

  ASSERT_EQ(dropped.ndimension(), 1);
  ASSERT_EQ(dropped.size(0), 100);
  // Statistical bounds: expected to hold with high probability, not always.
  const float train_sum = dropped.sum().item<float>();
  ASSERT_LT(train_sum, 130);
  ASSERT_GT(train_sum, 40);

  alpha_dropout->eval();
  ASSERT_EQ(alpha_dropout(ones).sum().item<float>(), 100);
}
1332*da0073e9SAndroid Build Coastguard Worker 
// In training mode FeatureAlphaDropout keeps the sum of a 10x10 ones tensor
// within a loose band. In eval mode it returns the input unchanged.
TEST_F(ModulesTest, FeatureAlphaDropout) {
  FeatureAlphaDropout feature_alpha_dropout(0.5);
  torch::Tensor ones = torch::ones({10, 10}, torch::requires_grad());
  torch::Tensor dropped = feature_alpha_dropout(ones);

  dropped.backward(torch::ones_like(dropped));

  ASSERT_EQ(dropped.ndimension(), 2);
  ASSERT_EQ(dropped.size(0), 10);
  ASSERT_EQ(dropped.size(1), 10);
  // Statistical bounds: expected to hold with high probability, not always.
  const float train_sum = dropped.sum().item<float>();
  ASSERT_LT(train_sum, 130);
  ASSERT_GT(train_sum, 40);

  feature_alpha_dropout->eval();
  ASSERT_EQ(feature_alpha_dropout(ones).sum().item<float>(), 100);
}
1351*da0073e9SAndroid Build Coastguard Worker 
// Dropout with p = 0.5, exercised both out-of-place (with autograd) and
// in-place. In eval mode it must be the identity.
TEST_F(ModulesTest, Dropout) {
  for (const bool inplace : {false, true}) {
    Dropout dropout(DropoutOptions(0.5).inplace(inplace));
    torch::Tensor input = torch::ones(100);
    // Only the out-of-place variant tracks gradients. An in-place op on a
    // grad-requiring leaf is not allowed by autograd.
    if (!inplace) {
      input.requires_grad_(true);
    }
    torch::Tensor output = dropout(input);

    ASSERT_EQ(output.ndimension(), 1);
    ASSERT_EQ(output.size(0), 100);
    // Statistical bounds: expected to hold with high probability, not always.
    const float total = output.sum().item<float>();
    ASSERT_LT(total, 130);
    ASSERT_GT(total, 70);
    if (!inplace) {
      output.backward(torch::ones_like(output));
    } else {
      // In-place mode must leave `input` matching the returned tensor.
      ASSERT_TRUE(output.allclose(input));
    }

    dropout->eval();
    ASSERT_EQ(dropout(torch::ones(100)).sum().item<float>(), 100);
  }
}
1376*da0073e9SAndroid Build Coastguard Worker 
// In training mode Dropout2d keeps the mean of an input filled with (1 - p)
// close to (1 - p), both out-of-place and in-place. In eval mode it is the
// identity.
TEST_F(ModulesTest, Dropout2d) {
  const double p = 0.5;
  for (const bool inplace : {false, true}) {
    Dropout2d dropout(Dropout2dOptions(p).inplace(inplace));
    torch::Tensor input = torch::empty({50, 50, 2, 2}).fill_(1 - p);
    if (!inplace) {
      input.requires_grad_(true);
    }
    torch::Tensor output = dropout(input);

    ASSERT_EQ(output.sizes(), std::vector<int64_t>({50, 50, 2, 2}));
    ASSERT_LT((output.mean() - (1 - p)).abs().item<float>(), 0.05);

    if (!inplace) {
      output.backward(torch::ones_like(output));
    } else {
      ASSERT_TRUE(output.allclose(input));
    }

    dropout->eval();
    ASSERT_EQ(dropout(torch::ones({2, 2, 10, 10})).sum().item<float>(), 400);
  }
}
1405*da0073e9SAndroid Build Coastguard Worker 
// In training mode Dropout3d keeps the mean of an input filled with (1 - p)
// close to (1 - p), both out-of-place and in-place. In eval mode it is the
// identity.
TEST_F(ModulesTest, Dropout3d) {
  const double p = 0.5;
  for (const bool inplace : {false, true}) {
    Dropout3d dropout(Dropout3dOptions(p).inplace(inplace));
    torch::Tensor input = torch::empty({50, 50, 2, 2, 2}).fill_(1 - p);
    if (!inplace) {
      input.requires_grad_(true);
    }
    torch::Tensor output = dropout(input);

    ASSERT_EQ(output.sizes(), std::vector<int64_t>({50, 50, 2, 2, 2}));
    ASSERT_LT((output.mean() - (1 - p)).abs().item<float>(), 0.05);

    if (!inplace) {
      output.backward(torch::ones_like(output));
    } else {
      ASSERT_TRUE(output.allclose(input));
    }

    dropout->eval();
    ASSERT_EQ(dropout(torch::ones({4, 4, 5, 5})).sum().item<float>(), 400);
  }
}
1435*da0073e9SAndroid Build Coastguard Worker 
// named_parameters() on NestedModel reports a top-level `param` plus the
// parameters of its submodules under dotted, hierarchical names.
TEST_F(ModulesTest, Parameters) {
  auto model = std::make_shared<NestedModel>();
  auto params = model->named_parameters();

  // Each entry: parameter name, dimension index, expected size there.
  struct Expected {
    const char* name;
    int64_t dim;
    int64_t size;
  };
  const std::vector<Expected> expectations = {
      {"param", 0, 3},
      {"param", 1, 2},
      {"param", 2, 21},
      {"l1.bias", 0, 20},
      {"l1.weight", 0, 20},
      {"l1.weight", 1, 5},
      {"test.l1.bias", 0, 3},
      {"test.l1.weight", 0, 3},
      {"test.l1.weight", 1, 10},
      {"test.l2.bias", 0, 5},
      {"test.l2.weight", 0, 5},
      {"test.l2.weight", 1, 3},
      {"test.l3.bias", 0, 100},
      {"test.l3.weight", 0, 100},
      {"test.l3.weight", 1, 5},
  };
  for (const auto& e : expectations) {
    ASSERT_EQ(params[e.name].size(e.dim), e.size) << e.name;
  }
}
1455*da0073e9SAndroid Build Coastguard Worker 
// Functional wraps an arbitrary callable. Both the explicit forward() entry
// point and the call operator must dispatch to that callable.
TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
  bool was_called = false;
  auto functional = Functional([&was_called](torch::Tensor input) {
    was_called = true;
    return input;
  });
  // Call forward() on the wrapped FunctionalImpl directly. Previously both
  // calls used operator(), so forward() was never exercised and the comment
  // on the second call distinguished nothing.
  auto output = functional->forward(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));

  was_called = false;
  // Use the call operator overload here.
  output = functional(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
}
1472*da0073e9SAndroid Build Coastguard Worker 
// Functional can wrap a plain torch function such as relu. It is checked
// through forward() and through the call operator.
TEST_F(ModulesTest, FunctionalWithTorchFunction) {
  auto functional = Functional(torch::relu);
  // The first check previously duplicated the second verbatim; route it
  // through forward() so both entry points are covered.
  ASSERT_EQ(functional->forward(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
}
1479*da0073e9SAndroid Build Coastguard Worker 
// Extra arguments given to Functional are bound after the input tensor. With
// scale = 0, elu maps the input 1 to 0.
TEST_F(ModulesTest, FunctionalArgumentBinding) {
  auto scaled_elu =
      Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
  const auto result = scaled_elu(torch::ones({}));
  ASSERT_EQ(result.item<float>(), 0);
}
1485*da0073e9SAndroid Build Coastguard Worker 
// A default-constructed BatchNorm1d tracks running statistics and has affine
// parameters. Each of these is a 1-D tensor with one entry per feature.
TEST_F(ModulesTest, BatchNorm1dStateful) {
  const int64_t num_features = 5;
  BatchNorm1d bn(num_features);

  // Asserts that `t` is a defined 1-D tensor of length `num_features`.
  const auto expect_per_feature = [&](const torch::Tensor& t) {
    ASSERT_TRUE(t.defined());
    ASSERT_EQ(t.dim(), 1);
    ASSERT_EQ(t.size(0), num_features);
  };

  ASSERT_TRUE(bn->options.track_running_stats());
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->running_mean));
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->running_var));

  // The batch counter is a scalar (0-D) tensor.
  ASSERT_TRUE(bn->num_batches_tracked.defined());
  ASSERT_EQ(bn->num_batches_tracked.dim(), 0);

  ASSERT_TRUE(bn->options.affine());
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->weight));
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->bias));
}
1512*da0073e9SAndroid Build Coastguard Worker 
// With running-stat tracking and affine parameters both disabled, every
// optional buffer and parameter is left undefined.
TEST_F(ModulesTest, BatchNorm1dStateless) {
  const auto options =
      BatchNorm1dOptions(5).track_running_stats(false).affine(false);
  BatchNorm1d bn(options);

  for (const auto& member :
       {bn->running_mean,
        bn->running_var,
        bn->num_batches_tracked,
        bn->weight,
        bn->bias}) {
    ASSERT_FALSE(member.defined());
  }
}
1523*da0073e9SAndroid Build Coastguard Worker 
// A fresh BatchNorm1d in eval mode normalizes with its initial running
// statistics. The expected output is almost exactly the input, which fits
// mean-0 / variance-1 initial stats plus a small eps (presumably the
// BatchNormOptions defaults; confirm there). Gradients must reach the input.
TEST_F(ModulesTest, BatchNorm1d) {
  BatchNorm1d bn(5);
  bn->eval();

  // 20 consecutive values shaped (2, 5, 2); dimension 1 matches the 5
  // features of the module.
  auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
  auto output = bn(input);
  const auto want = torch::tensor(
      {{{0.0000, 1.0000},
        {2.0000, 3.0000},
        {4.0000, 5.0000},
        {6.0000, 7.0000},
        {8.0000, 9.0000}},
       {{10.0000, 10.9999},
        {11.9999, 12.9999},
        {13.9999, 14.9999},
        {15.9999, 16.9999},
        {17.9999, 18.9999}}});
  ASSERT_TRUE(output.allclose(want));

  output.sum().backward();
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
1547*da0073e9SAndroid Build Coastguard Worker 
// A default-constructed BatchNorm2d tracks running statistics and has affine
// parameters. Each of these is a 1-D tensor with one entry per feature.
TEST_F(ModulesTest, BatchNorm2dStateful) {
  const int64_t num_features = 5;
  BatchNorm2d bn(num_features);

  // Asserts that `t` is a defined 1-D tensor of length `num_features`.
  const auto expect_per_feature = [&](const torch::Tensor& t) {
    ASSERT_TRUE(t.defined());
    ASSERT_EQ(t.dim(), 1);
    ASSERT_EQ(t.size(0), num_features);
  };

  ASSERT_TRUE(bn->options.track_running_stats());
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->running_mean));
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->running_var));

  // The batch counter is a scalar (0-D) tensor.
  ASSERT_TRUE(bn->num_batches_tracked.defined());
  ASSERT_EQ(bn->num_batches_tracked.dim(), 0);

  ASSERT_TRUE(bn->options.affine());
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->weight));
  ASSERT_NO_FATAL_FAILURE(expect_per_feature(bn->bias));
}
1574*da0073e9SAndroid Build Coastguard Worker 
// With running-stat tracking and affine parameters both disabled, every
// optional buffer and parameter is left undefined.
TEST_F(ModulesTest, BatchNorm2dStateless) {
  const auto options =
      BatchNorm2dOptions(5).track_running_stats(false).affine(false);
  BatchNorm2d bn(options);

  for (const auto& member :
       {bn->running_mean,
        bn->running_var,
        bn->num_batches_tracked,
        bn->weight,
        bn->bias}) {
    ASSERT_FALSE(member.defined());
  }
}
1585*da0073e9SAndroid Build Coastguard Worker 
// A fresh BatchNorm2d in eval mode normalizes with its initial running
// statistics. The expected output is almost exactly the input, which fits
// mean-0 / variance-1 initial stats plus a small eps (presumably the
// BatchNormOptions defaults; confirm there). Gradients must reach the input.
TEST_F(ModulesTest, BatchNorm2d) {
  BatchNorm2d bn(5);
  bn->eval();

  // 40 consecutive values shaped (2, 5, 2, 2); dimension 1 matches the 5
  // features of the module.
  auto input =
      torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
  auto output = bn(input);
  const auto want = torch::tensor(
      {{{{0.0000, 1.0000}, {2.0000, 3.0000}},
        {{4.0000, 5.0000}, {6.0000, 7.0000}},
        {{8.0000, 9.0000}, {10.0000, 10.9999}},
        {{11.9999, 12.9999}, {13.9999, 14.9999}},
        {{15.9999, 16.9999}, {17.9999, 18.9999}}},
       {{{19.9999, 20.9999}, {21.9999, 22.9999}},
        {{23.9999, 24.9999}, {25.9999, 26.9999}},
        {{27.9999, 28.9999}, {29.9998, 30.9998}},
        {{31.9998, 32.9998}, {33.9998, 34.9998}},
        {{35.9998, 36.9998}, {37.9998, 38.9998}}}});
  ASSERT_TRUE(output.allclose(want));

  output.sum().backward();
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
1610*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, BatchNorm3dStateful) {
  // With default options BatchNorm3d is asserted below to track running
  // statistics and to be affine, so every buffer and parameter must be
  // allocated: 1-D with one entry per feature (5), except the 0-dim
  // num_batches_tracked.
  BatchNorm3d norm(5);
  const auto& opts = norm->options;

  ASSERT_TRUE(opts.track_running_stats());

  // Running-statistics buffers.
  ASSERT_TRUE(norm->running_mean.defined());
  ASSERT_EQ(norm->running_mean.dim(), 1);
  ASSERT_EQ(norm->running_mean.size(0), 5);

  ASSERT_TRUE(norm->running_var.defined());
  ASSERT_EQ(norm->running_var.dim(), 1);
  ASSERT_EQ(norm->running_var.size(0), 5);

  ASSERT_TRUE(norm->num_batches_tracked.defined());
  ASSERT_EQ(norm->num_batches_tracked.dim(), 0);

  // Learnable affine parameters.
  ASSERT_TRUE(opts.affine());

  ASSERT_TRUE(norm->weight.defined());
  ASSERT_EQ(norm->weight.dim(), 1);
  ASSERT_EQ(norm->weight.size(0), 5);

  ASSERT_TRUE(norm->bias.defined());
  ASSERT_EQ(norm->bias.dim(), 1);
  ASSERT_EQ(norm->bias.size(0), 5);
}
1637*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, BatchNorm3dStateless) {
  // Disabling both running-stat tracking and the affine transform means the
  // module must not allocate any buffer or parameter.
  BatchNorm3d norm(
      BatchNorm3dOptions(5).track_running_stats(false).affine(false));

  // Buffers that only exist when track_running_stats is on.
  ASSERT_FALSE(norm->running_mean.defined());
  ASSERT_FALSE(norm->running_var.defined());
  ASSERT_FALSE(norm->num_batches_tracked.defined());

  // Parameters that only exist when affine is on.
  ASSERT_FALSE(norm->weight.defined());
  ASSERT_FALSE(norm->bias.defined());
}
1648*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, BatchNorm3d) {
  // Eval-mode BatchNorm3d with fresh statistics should nearly reproduce the
  // input; the reference values are the input scaled by a factor just below 1.
  BatchNorm3d norm(5);
  norm->eval();

  // Input of shape (N=2, C=5, D=2, H=2, W=2) holding 0, 1, ..., 79.
  auto x = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2});
  x.requires_grad_();

  const auto reference = torch::tensor(
      {{{{{0.0000, 1.0000}, {2.0000, 3.0000}},
         {{4.0000, 5.0000}, {6.0000, 7.0000}}},
        {{{8.0000, 9.0000}, {10.0000, 10.9999}},
         {{11.9999, 12.9999}, {13.9999, 14.9999}}},
        {{{15.9999, 16.9999}, {17.9999, 18.9999}},
         {{19.9999, 20.9999}, {21.9999, 22.9999}}},
        {{{23.9999, 24.9999}, {25.9999, 26.9999}},
         {{27.9999, 28.9999}, {29.9998, 30.9998}}},
        {{{31.9998, 32.9998}, {33.9998, 34.9998}},
         {{35.9998, 36.9998}, {37.9998, 38.9998}}}},
       {{{{39.9998, 40.9998}, {41.9998, 42.9998}},
         {{43.9998, 44.9998}, {45.9998, 46.9998}}},
        {{{47.9998, 48.9998}, {49.9997, 50.9997}},
         {{51.9997, 52.9997}, {53.9997, 54.9997}}},
        {{{55.9997, 56.9997}, {57.9997, 58.9997}},
         {{59.9997, 60.9997}, {61.9997, 62.9997}}},
        {{{63.9997, 64.9997}, {65.9997, 66.9997}},
         {{67.9997, 68.9997}, {69.9996, 70.9996}}},
        {{{71.9996, 72.9996}, {73.9996, 74.9996}},
         {{75.9996, 76.9996}, {77.9996, 78.9996}}}}});

  auto y = norm->forward(x);
  ASSERT_TRUE(y.allclose(reference));

  // The gradient of a scalar reduction must match the input's shape.
  auto total = y.sum();
  total.backward();
  ASSERT_EQ(x.sizes(), x.grad().sizes());
}
1683*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm1dStateful) {
  // Explicitly enabling tracking and affine must allocate every buffer and
  // parameter: 1-D with one entry per feature (5), except the 0-dim
  // num_batches_tracked.
  InstanceNorm1d norm(
      InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
  const auto& opts = norm->options;

  ASSERT_TRUE(opts.track_running_stats());

  // Running-statistics buffers.
  ASSERT_TRUE(norm->running_mean.defined());
  ASSERT_EQ(norm->running_mean.dim(), 1);
  ASSERT_EQ(norm->running_mean.size(0), 5);

  ASSERT_TRUE(norm->running_var.defined());
  ASSERT_EQ(norm->running_var.dim(), 1);
  ASSERT_EQ(norm->running_var.size(0), 5);

  ASSERT_TRUE(norm->num_batches_tracked.defined());
  ASSERT_EQ(norm->num_batches_tracked.dim(), 0);

  // Learnable affine parameters.
  ASSERT_TRUE(opts.affine());

  ASSERT_TRUE(norm->weight.defined());
  ASSERT_EQ(norm->weight.dim(), 1);
  ASSERT_EQ(norm->weight.size(0), 5);

  ASSERT_TRUE(norm->bias.defined());
  ASSERT_EQ(norm->bias.dim(), 1);
  ASSERT_EQ(norm->bias.size(0), 5);
}
1711*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm1dStateless) {
  // With tracking and affine both off, nothing should be allocated.
  InstanceNorm1d norm(
      InstanceNorm1dOptions(5).track_running_stats(false).affine(false));

  // Running-statistics buffers.
  ASSERT_FALSE(norm->running_mean.defined());
  ASSERT_FALSE(norm->running_var.defined());
  ASSERT_FALSE(norm->num_batches_tracked.defined());

  // Affine parameters.
  ASSERT_FALSE(norm->weight.defined());
  ASSERT_FALSE(norm->bias.defined());
}
1722*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm1d) {
  // Each (sample, channel) row of the input is a pair {k, k + 1}; normalizing
  // over that row maps it to approximately {-1, +1}.
  InstanceNorm1d norm(5);
  norm->eval();

  // Input of shape (N=2, C=5, L=2) holding 0, 1, ..., 19.
  auto x = torch::arange(2. * 5 * 2).view({2, 5, 2});
  x.requires_grad_();

  const auto reference = torch::tensor(
      {{{-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000}},
       {{-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000},
        {-1.0000, 1.0000}}});

  auto y = norm->forward(x);
  // Relative tolerance 1e-3 absorbs the eps in the normalization denominator.
  ASSERT_TRUE(y.allclose(reference, 1e-3));

  auto total = y.sum();
  total.backward();
  ASSERT_EQ(x.sizes(), x.grad().sizes());
}
1746*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm2dStateful) {
  // Explicitly enabling tracking and affine must allocate every buffer and
  // parameter: 1-D with one entry per feature (5), except the 0-dim
  // num_batches_tracked.
  InstanceNorm2d norm(
      InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
  const auto& opts = norm->options;

  ASSERT_TRUE(opts.track_running_stats());

  // Running-statistics buffers.
  ASSERT_TRUE(norm->running_mean.defined());
  ASSERT_EQ(norm->running_mean.dim(), 1);
  ASSERT_EQ(norm->running_mean.size(0), 5);

  ASSERT_TRUE(norm->running_var.defined());
  ASSERT_EQ(norm->running_var.dim(), 1);
  ASSERT_EQ(norm->running_var.size(0), 5);

  ASSERT_TRUE(norm->num_batches_tracked.defined());
  ASSERT_EQ(norm->num_batches_tracked.dim(), 0);

  // Learnable affine parameters.
  ASSERT_TRUE(opts.affine());

  ASSERT_TRUE(norm->weight.defined());
  ASSERT_EQ(norm->weight.dim(), 1);
  ASSERT_EQ(norm->weight.size(0), 5);

  ASSERT_TRUE(norm->bias.defined());
  ASSERT_EQ(norm->bias.dim(), 1);
  ASSERT_EQ(norm->bias.size(0), 5);
}
1774*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm2dStateless) {
  // With tracking and affine both off, nothing should be allocated.
  InstanceNorm2d norm(
      InstanceNorm2dOptions(5).track_running_stats(false).affine(false));

  // Running-statistics buffers.
  ASSERT_FALSE(norm->running_mean.defined());
  ASSERT_FALSE(norm->running_var.defined());
  ASSERT_FALSE(norm->num_batches_tracked.defined());

  // Affine parameters.
  ASSERT_FALSE(norm->weight.defined());
  ASSERT_FALSE(norm->bias.defined());
}
1785*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm2d) {
  // Every (sample, channel) plane is {k, k+1, k+2, k+3}; standardizing it over
  // the plane yields the same four values regardless of k.
  InstanceNorm2d norm(5);
  norm->eval();

  // Input of shape (N=2, C=5, H=2, W=2) holding 0, 1, ..., 39.
  auto x = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2});
  x.requires_grad_();

  const auto reference = torch::tensor(
      {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
       {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}},
        {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});

  auto y = norm->forward(x);
  ASSERT_TRUE(y.allclose(reference, 1e-3));

  auto total = y.sum();
  total.backward();
  ASSERT_EQ(x.sizes(), x.grad().sizes());
}
1810*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm3dStateful) {
  // Explicitly enabling tracking and affine must allocate every buffer and
  // parameter: 1-D with one entry per feature (5), except the 0-dim
  // num_batches_tracked.
  InstanceNorm3d norm(
      InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
  const auto& opts = norm->options;

  ASSERT_TRUE(opts.track_running_stats());

  // Running-statistics buffers.
  ASSERT_TRUE(norm->running_mean.defined());
  ASSERT_EQ(norm->running_mean.dim(), 1);
  ASSERT_EQ(norm->running_mean.size(0), 5);

  ASSERT_TRUE(norm->running_var.defined());
  ASSERT_EQ(norm->running_var.dim(), 1);
  ASSERT_EQ(norm->running_var.size(0), 5);

  ASSERT_TRUE(norm->num_batches_tracked.defined());
  ASSERT_EQ(norm->num_batches_tracked.dim(), 0);

  // Learnable affine parameters.
  ASSERT_TRUE(opts.affine());

  ASSERT_TRUE(norm->weight.defined());
  ASSERT_EQ(norm->weight.dim(), 1);
  ASSERT_EQ(norm->weight.size(0), 5);

  ASSERT_TRUE(norm->bias.defined());
  ASSERT_EQ(norm->bias.dim(), 1);
  ASSERT_EQ(norm->bias.size(0), 5);
}
1838*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm3dStateless) {
  // With tracking and affine both off, nothing should be allocated.
  InstanceNorm3d norm(
      InstanceNorm3dOptions(5).track_running_stats(false).affine(false));

  // Running-statistics buffers.
  ASSERT_FALSE(norm->running_mean.defined());
  ASSERT_FALSE(norm->running_var.defined());
  ASSERT_FALSE(norm->num_batches_tracked.defined());

  // Affine parameters.
  ASSERT_FALSE(norm->weight.defined());
  ASSERT_FALSE(norm->bias.defined());
}
1849*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, InstanceNorm3d) {
  // Every (sample, channel) volume is eight consecutive integers; standardizing
  // over the volume produces the same eight values regardless of the offset.
  InstanceNorm3d norm(5);
  norm->eval();

  // Input of shape (N=2, C=5, D=2, H=2, W=2) holding 0, 1, ..., 79.
  auto x = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2});
  x.requires_grad_();

  const auto reference = torch::tensor(
      {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
       {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});

  auto y = norm->forward(x);
  ASSERT_TRUE(y.allclose(reference, 1e-3));

  auto total = y.sum();
  total.backward();
  ASSERT_EQ(x.sizes(), x.grad().sizes());
}
1884*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Linear_CUDA) {
  // A 5 -> 2 linear layer moved to the GPU must run forward and backward on a
  // CUDA input and populate the weight gradient.
  Linear linear(5, 2);
  linear->to(torch::kCUDA);

  const auto cuda_opts = torch::device(torch::kCUDA).requires_grad(true);
  auto x = torch::randn({10, 5}, cuda_opts);
  auto y = linear(x);
  torch::Tensor total = y.sum();
  total.backward();

  // Output is (batch=10, out_features=2); the reduction is a 0-dim scalar.
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  // Weight gradient has one entry per weight element (out * in).
  ASSERT_EQ(linear->weight.grad().numel(), 2 * 5);
}
1901*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Linear2_CUDA) {
  // Round-trip the layer CUDA -> CPU, then check that it still works on a
  // CPU input, including the backward pass.
  Linear linear(5, 2);
  linear->to(torch::kCUDA);
  linear->to(torch::kCPU);

  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = linear(x);
  torch::Tensor total = y.sum();
  total.backward();

  // Output is (batch=10, out_features=2); the reduction is a 0-dim scalar.
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  // Weight gradient has one entry per weight element (out * in).
  ASSERT_EQ(linear->weight.grad().numel(), 2 * 5);
}
1918*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, L1Loss) {
  // Smoke test: with default options the loss reduces to a 0-dim scalar (as
  // asserted below) and gradients flow back to an input-shaped tensor.
  L1Loss loss;
  auto input = torch::randn({5, 6}, torch::requires_grad());
  // Integer-valued targets in [0, 2), i.e. 0 or 1.
  auto target = torch::empty({5, 6}).random_(2);
  auto output = loss->forward(torch::sigmoid(input), target);
  auto s = output.sum();
  s.backward();

  // An empty IntArrayRef denotes a 0-dim (scalar) shape; this matches the
  // comparison used by the other loss tests in this file.
  ASSERT_EQ(output.sizes(), torch::IntArrayRef());
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
1930*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MSELoss) {
  // Smoke test: the default-configured loss yields a 0-dim scalar, and the
  // backward pass produces an input-shaped gradient.
  MSELoss criterion;
  auto logits = torch::randn({5, 6}, torch::requires_grad());
  // Integer-valued targets in [0, 2), i.e. 0 or 1.
  auto labels = torch::empty({5, 6}).random_(2);
  auto probs = torch::sigmoid(logits);
  auto value = criterion->forward(probs, labels);

  auto total = value.sum();
  total.backward();

  ASSERT_EQ(value.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
1942*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, BCELoss) {
  // Smoke test: sigmoid keeps the predictions inside (0, 1), which is the
  // input domain of binary cross-entropy; the loss must be a 0-dim scalar
  // with an input-shaped gradient.
  BCELoss criterion;
  auto logits = torch::randn({5, 6}, torch::requires_grad());
  // Integer-valued targets in [0, 2), i.e. 0 or 1.
  auto labels = torch::empty({5, 6}).random_(2);
  auto probs = torch::sigmoid(logits);
  auto value = criterion->forward(probs, labels);

  auto total = value.sum();
  total.backward();

  ASSERT_EQ(value.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
1954*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, KLDivLoss) {
  // Smoke test for output shape and gradient flow only.
  // NOTE(review): KLDivLoss is documented to take log-probabilities as input;
  // the sigmoid output used here is not, so no numeric value is checked.
  KLDivLoss criterion;
  auto logits = torch::randn({5, 6}, torch::requires_grad());
  // Integer-valued targets in [0, 2), i.e. 0 or 1.
  auto labels = torch::empty({5, 6}).random_(2);
  auto probs = torch::sigmoid(logits);
  auto value = criterion->forward(probs, labels);

  auto total = value.sum();
  total.backward();

  ASSERT_EQ(value.sizes(), torch::IntArrayRef());
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
1966*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, HingeEmbeddingLoss) {
  // HingeEmbeddingLoss is only defined for targets in {1, -1}:
  //   target ==  1  ->  x
  //   target == -1  ->  max(0, margin - x)
  // The targets below use only those values and exercise both branches,
  // including a -1 entry whose input (0) lies inside the margin (2) and so
  // contributes a non-zero hinge term. (Previously the targets were outside
  // {1, -1}, which made the expected value depend on implementation details.)
  HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2));
  auto input = torch::tensor(
      {{2, 22, 4}, {20, 10, 0}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({{1, -1, 1}, {1, -1, -1}}, torch::kFloat);
  auto output = loss->forward(input, target);
  // Element-wise losses: {2, 0, 4, 20, 0, 2}; the mean over 6 elements is
  // 28 / 6.
  auto expected = torch::tensor(28.0 / 6.0, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
1981*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MultiMarginLoss) {
  // Weighted multi-class hinge loss with margin 2 on a 3-sample, 3-class
  // batch. The reference value is 0.305556, checked at relative tolerance
  // 1e-4.
  const auto class_weights = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
  MultiMarginLoss criterion(
      MultiMarginLossOptions().margin(2).weight(class_weights));

  auto scores = torch::tensor(
      {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
      torch::dtype(torch::kFloat).requires_grad(true));
  // One class index per sample.
  auto labels = torch::tensor({2, 1, 0}, torch::kLong);
  const auto reference = torch::tensor({0.305556}, torch::kFloat);

  auto value = criterion->forward(scores, labels);
  auto total = value.sum();
  total.backward();

  ASSERT_TRUE(value.allclose(reference, 1e-04));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
1997*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, CosineEmbeddingLoss) {
  // Two pairs of 3-vectors: the first pair is labelled similar (+1), the
  // second dissimilar (-1), with margin 0.5. The reference value is 0.1004,
  // checked at relative tolerance 1e-4.
  CosineEmbeddingLoss criterion(CosineEmbeddingLossOptions().margin(0.5));

  const auto grad_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto lhs = torch::tensor({{2, 3, 4}, {6, 2, 4}}, grad_opts);
  auto rhs = torch::tensor({{2, 3, 5}, {9, 12, 0}}, grad_opts);
  auto labels = torch::tensor({1, -1});
  const auto reference = torch::tensor({0.1004}, torch::kFloat);

  auto value = criterion(lhs, rhs, labels);
  auto total = value.sum();
  total.backward();

  ASSERT_TRUE(value.allclose(reference, 1e-4));
  // Both embeddings receive input-shaped gradients.
  ASSERT_EQ(lhs.sizes(), lhs.grad().sizes());
  ASSERT_EQ(rhs.sizes(), rhs.grad().sizes());
}
2014*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SmoothL1LossDefaultOptions) {
  // Default-constructed module. The elementwise differences are 0.1, 0.2 and
  // 0.3, and the expected scalar is (0.005 + 0.02 + 0.045) / 3.
  SmoothL1Loss criterion;
  auto prediction = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);

  const auto result = criterion(prediction, goal);
  result.sum().backward();

  const auto reference = torch::tensor(0.0233335, torch::kFloat);
  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2028*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, HuberLossDefaultOptions) {
  // Default-constructed HuberLoss on three elements whose differences from
  // the target are 0.1, 0.2 and 0.3.
  HuberLoss criterion;
  auto prediction = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);
  const auto reference = torch::tensor(0.0233335, torch::kFloat);

  const auto result = criterion(prediction, goal);
  // A scalar root is required for backward().
  const auto root = result.sum();
  root.backward();

  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2042*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) {
  // Single sample with four classes. NOTE(review): the expected 0.85 is
  // reproduced only when the -1 ends the target list (targets {3, 0}), which
  // presumably follows PyTorch's MultiLabelMarginLoss convention — confirm.
  MultiLabelMarginLoss criterion;
  auto scores = torch::tensor(
      {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto label_rows = torch::tensor({{3, 0, -1, 1}}, torch::kLong);

  const auto result = criterion->forward(scores, label_rows);
  result.sum().backward();

  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
2056*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SmoothL1LossNoReduction) {
  // With reduction disabled, one loss value is returned per element. The
  // expected entries equal 0.5 * d^2 for differences d = 0.1, 0.2, 0.3.
  SmoothL1Loss criterion(/*reduction=*/torch::kNone);
  auto prediction = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);

  const auto per_element = criterion(prediction, goal);
  const auto reference = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  // Sum so that backward() has a scalar root.
  per_element.sum().backward();

  ASSERT_TRUE(per_element.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2070*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, HuberLossNoReduction) {
  // Unreduced HuberLoss: one value per element, checked against 0.5 * d^2
  // for differences d = 0.1, 0.2, 0.3.
  HuberLoss criterion(/*reduction=*/torch::kNone);
  const auto grad_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto prediction = torch::tensor({0.1, 1.2, 4.7}, grad_opts);
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);

  const auto per_element = criterion(prediction, goal);
  per_element.sum().backward();

  ASSERT_TRUE(per_element.allclose(
      torch::tensor({0.005, 0.02, 0.045}, torch::kFloat)));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2084*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) {
  // Same single-sample data as the default-options case, but with reduction
  // disabled. The expected per-sample value is 0.85.
  MultiLabelMarginLoss criterion(torch::kNone);
  const auto grad_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto scores = torch::tensor({{0.1, 0.2, 0.4, 0.8}}, grad_opts);
  const auto label_rows = torch::tensor({{3, 0, -1, 1}}, torch::kLong);

  const auto per_sample = criterion->forward(scores, label_rows);
  const auto reference = torch::tensor({0.8500}, torch::kFloat);
  per_sample.sum().backward();

  ASSERT_TRUE(per_sample.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
2098*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SmoothL1LossBeta) {
  // beta = 0.2 moves the quadratic/linear switch point.
  // - The 0.1 difference contributes 0.5 * 0.1^2 / 0.2 = 0.025.
  // - The 0.2 and 0.3 differences contribute d - 0.1.
  // Mean: (0.025 + 0.1 + 0.2) / 3 ~= 0.108333.
  SmoothL1Loss criterion(SmoothL1LossOptions().beta(0.2));
  auto prediction = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);

  const auto result = criterion(prediction, goal);
  result.sum().backward();

  ASSERT_TRUE(result.allclose(torch::tensor(0.108333, torch::kFloat)));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2113*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, HuberLossDelta) {
  // delta = 0.2: the per-element terms are 0.005, 0.02 and 0.04, and their
  // mean is ~0.0216666.
  HuberLoss criterion(HuberLossOptions().delta(0.2));
  const auto grad_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto prediction = torch::tensor({0.1, 1.2, 4.7}, grad_opts);
  const auto goal = torch::tensor({0., 1., 5.}, torch::kFloat);
  const auto reference = torch::tensor(0.0216666, torch::kFloat);

  const auto result = criterion(prediction, goal);
  result.sum().backward();

  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(prediction.sizes(), prediction.grad().sizes());
}
2128*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, TripletMarginLoss) {
  // The anchor (3, 3) lies much closer to the positive (2, 2) than to the
  // negative (0, 0). With margin 1.0 the expected loss is therefore zero.
  TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0));
  auto anchor = torch::tensor(
      {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto positive = torch::tensor(
      {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto negative = torch::tensor(
      {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = loss->forward(anchor, positive, negative);
  auto expected = torch::tensor({0.}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected, 1e-04));
  // All three inputs require grad and feed the loss, so each must receive a
  // gradient of matching shape, not just the anchor.
  ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
  ASSERT_EQ(positive.sizes(), positive.grad().sizes());
  ASSERT_EQ(negative.sizes(), negative.grad().sizes());
}
2145*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,TripletMarginWithDistanceLossDefaultParity)2146*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
2147*da0073e9SAndroid Build Coastguard Worker   // Check that if we use torch::pairwise_distance with the default
2148*da0073e9SAndroid Build Coastguard Worker   // TripletMarginLoss options as our distance function, the outputs
2149*da0073e9SAndroid Build Coastguard Worker   // are equal (i.e., equal under defaults).
2150*da0073e9SAndroid Build Coastguard Worker 
2151*da0073e9SAndroid Build Coastguard Worker   std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2152*da0073e9SAndroid Build Coastguard Worker       torch::kSum, torch::kMean, torch::kNone};
2153*da0073e9SAndroid Build Coastguard Worker   std::vector<float> margins = {0.5, 1.0, 1.5};
2154*da0073e9SAndroid Build Coastguard Worker   std::vector<bool> swaps = {true, false};
2155*da0073e9SAndroid Build Coastguard Worker 
2156*da0073e9SAndroid Build Coastguard Worker   for (auto& reduction : reductions) {
2157*da0073e9SAndroid Build Coastguard Worker     for (auto& margin : margins) {
2158*da0073e9SAndroid Build Coastguard Worker       for (const auto swap : swaps) {
2159*da0073e9SAndroid Build Coastguard Worker         auto anchor = torch::randn(
2160*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2161*da0073e9SAndroid Build Coastguard Worker         auto positive = torch::randn(
2162*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2163*da0073e9SAndroid Build Coastguard Worker         auto negative = torch::randn(
2164*da0073e9SAndroid Build Coastguard Worker             {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2165*da0073e9SAndroid Build Coastguard Worker 
2166*da0073e9SAndroid Build Coastguard Worker         auto basicOptions =
2167*da0073e9SAndroid Build Coastguard Worker             TripletMarginLossOptions().reduction(reduction).margin(margin).swap(
2168*da0073e9SAndroid Build Coastguard Worker                 swap);
2169*da0073e9SAndroid Build Coastguard Worker         auto distanceOptions = TripletMarginWithDistanceLossOptions()
2170*da0073e9SAndroid Build Coastguard Worker                                    .reduction(reduction)
2171*da0073e9SAndroid Build Coastguard Worker                                    .margin(margin)
2172*da0073e9SAndroid Build Coastguard Worker                                    .swap(swap);
2173*da0073e9SAndroid Build Coastguard Worker         TripletMarginLoss basicLoss(basicOptions);
2174*da0073e9SAndroid Build Coastguard Worker         TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
2175*da0073e9SAndroid Build Coastguard Worker 
2176*da0073e9SAndroid Build Coastguard Worker         auto basicOutput = basicLoss->forward(anchor, positive, negative);
2177*da0073e9SAndroid Build Coastguard Worker         auto distanceOutput = distanceLoss->forward(anchor, positive, negative);
2178*da0073e9SAndroid Build Coastguard Worker         auto basicOperatorOutput = basicLoss(anchor, positive, negative);
2179*da0073e9SAndroid Build Coastguard Worker         auto distanceOperatorOutput = distanceLoss(anchor, positive, negative);
2180*da0073e9SAndroid Build Coastguard Worker 
2181*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
2182*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(
2183*da0073e9SAndroid Build Coastguard Worker             distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6));
2184*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(
2185*da0073e9SAndroid Build Coastguard Worker             distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6));
2186*da0073e9SAndroid Build Coastguard Worker 
2187*da0073e9SAndroid Build Coastguard Worker         // handle for torch::kNone reduction
2188*da0073e9SAndroid Build Coastguard Worker         auto sum = distanceOutput.sum();
2189*da0073e9SAndroid Build Coastguard Worker         sum.backward();
2190*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2191*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(positive.sizes(), positive.grad().sizes());
2192*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(negative.sizes(), negative.grad().sizes());
2193*da0073e9SAndroid Build Coastguard Worker       }
2194*da0073e9SAndroid Build Coastguard Worker     }
2195*da0073e9SAndroid Build Coastguard Worker   }
2196*da0073e9SAndroid Build Coastguard Worker }
2197*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,TripletMarginWithDistanceLossFunctionalParity)2198*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
2199*da0073e9SAndroid Build Coastguard Worker   // Check for parity between F::triplet_margin_with_distance_loss and
2200*da0073e9SAndroid Build Coastguard Worker   // TripletMarginWithDistanceLoss.
2201*da0073e9SAndroid Build Coastguard Worker   auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2202*da0073e9SAndroid Build Coastguard Worker     return torch::pairwise_distance(x, y);
2203*da0073e9SAndroid Build Coastguard Worker   };
2204*da0073e9SAndroid Build Coastguard Worker   auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2205*da0073e9SAndroid Build Coastguard Worker     return 1.0 - torch::cosine_similarity(x, y);
2206*da0073e9SAndroid Build Coastguard Worker   };
2207*da0073e9SAndroid Build Coastguard Worker   std::vector<TripletMarginWithDistanceLossOptions::distance_function_t>
2208*da0073e9SAndroid Build Coastguard Worker       distance_functions = {pairwise_distance, cosine_distance};
2209*da0073e9SAndroid Build Coastguard Worker 
2210*da0073e9SAndroid Build Coastguard Worker   std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2211*da0073e9SAndroid Build Coastguard Worker       torch::kSum, torch::kMean, torch::kNone};
2212*da0073e9SAndroid Build Coastguard Worker   std::vector<float> margins = {0.5, 1.0, 1.5};
2213*da0073e9SAndroid Build Coastguard Worker   std::vector<bool> swaps = {true, false};
2214*da0073e9SAndroid Build Coastguard Worker 
2215*da0073e9SAndroid Build Coastguard Worker   for (auto& function : distance_functions) {
2216*da0073e9SAndroid Build Coastguard Worker     for (auto& reduction : reductions) {
2217*da0073e9SAndroid Build Coastguard Worker       for (auto& margin : margins) {
2218*da0073e9SAndroid Build Coastguard Worker         for (const auto swap : swaps) {
2219*da0073e9SAndroid Build Coastguard Worker           auto moduleOptions = TripletMarginWithDistanceLossOptions()
2220*da0073e9SAndroid Build Coastguard Worker                                    .distance_function(function)
2221*da0073e9SAndroid Build Coastguard Worker                                    .reduction(reduction)
2222*da0073e9SAndroid Build Coastguard Worker                                    .margin(margin)
2223*da0073e9SAndroid Build Coastguard Worker                                    .swap(swap);
2224*da0073e9SAndroid Build Coastguard Worker           auto functionOptions =
2225*da0073e9SAndroid Build Coastguard Worker               torch::nn::functional::TripletMarginWithDistanceLossFuncOptions()
2226*da0073e9SAndroid Build Coastguard Worker                   .distance_function(function)
2227*da0073e9SAndroid Build Coastguard Worker                   .reduction(reduction)
2228*da0073e9SAndroid Build Coastguard Worker                   .margin(margin)
2229*da0073e9SAndroid Build Coastguard Worker                   .swap(swap);
2230*da0073e9SAndroid Build Coastguard Worker 
2231*da0073e9SAndroid Build Coastguard Worker           auto anchor = torch::randn(
2232*da0073e9SAndroid Build Coastguard Worker               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2233*da0073e9SAndroid Build Coastguard Worker           auto positive = torch::randn(
2234*da0073e9SAndroid Build Coastguard Worker               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2235*da0073e9SAndroid Build Coastguard Worker           auto negative = torch::randn(
2236*da0073e9SAndroid Build Coastguard Worker               {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2237*da0073e9SAndroid Build Coastguard Worker 
2238*da0073e9SAndroid Build Coastguard Worker           TripletMarginWithDistanceLoss distanceLoss(moduleOptions);
2239*da0073e9SAndroid Build Coastguard Worker 
2240*da0073e9SAndroid Build Coastguard Worker           auto moduleOutput = distanceLoss->forward(anchor, positive, negative);
2241*da0073e9SAndroid Build Coastguard Worker           auto moduleOperatorOutput = distanceLoss(anchor, positive, negative);
2242*da0073e9SAndroid Build Coastguard Worker           auto functionOutput =
2243*da0073e9SAndroid Build Coastguard Worker               torch::nn::functional::triplet_margin_with_distance_loss(
2244*da0073e9SAndroid Build Coastguard Worker                   anchor, positive, negative, functionOptions);
2245*da0073e9SAndroid Build Coastguard Worker 
2246*da0073e9SAndroid Build Coastguard Worker           ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6));
2247*da0073e9SAndroid Build Coastguard Worker           ASSERT_TRUE(
2248*da0073e9SAndroid Build Coastguard Worker               moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6));
2249*da0073e9SAndroid Build Coastguard Worker         }
2250*da0073e9SAndroid Build Coastguard Worker       }
2251*da0073e9SAndroid Build Coastguard Worker     }
2252*da0073e9SAndroid Build Coastguard Worker   }
2253*da0073e9SAndroid Build Coastguard Worker }
2254*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, NLLLoss) {
  // Each input row is a row of log-probabilities (the exps of a row sum to
  // ~1). The expected loss is the mean of -input[i][target[i]]:
  // (3.1315 + 3.7038 + 0.4422) / 3 ~= 2.4258.
  NLLLoss loss;
  auto input = torch::tensor(
      {{-0.1315, -3.1315, -2.5315},
       {-3.7038, -0.1038, -2.6038},
       {-2.3422, -1.3422, -0.4422}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({1, 0, 2}, torch::kLong);
  auto output = loss->forward(input, target);
  auto expected = torch::tensor(2.4258, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected, 1e-04));
  // Verify the backward pass above actually produced an input-shaped
  // gradient (previously the backward ran but was never checked).
  ASSERT_EQ(input.sizes(), input.grad().sizes());
  // Spelling out ignore_index=-100 and mean reduction must reproduce the
  // default-constructed module's result.
  ASSERT_TRUE(
      NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean))
          ->forward(input, target)
          .allclose(expected, 1e-04));
}
2274*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,CrossEntropyLoss)2275*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, CrossEntropyLoss) {
2276*da0073e9SAndroid Build Coastguard Worker   CrossEntropyLoss loss;
2277*da0073e9SAndroid Build Coastguard Worker   auto input = torch::tensor(
2278*da0073e9SAndroid Build Coastguard Worker       {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2279*da0073e9SAndroid Build Coastguard Worker   auto target = torch::tensor({0, 1}, torch::kLong);
2280*da0073e9SAndroid Build Coastguard Worker   auto output = loss->forward(input, target);
2281*da0073e9SAndroid Build Coastguard Worker   auto expected = torch::tensor(0.6931, torch::kFloat);
2282*da0073e9SAndroid Build Coastguard Worker   auto s = output.sum();
2283*da0073e9SAndroid Build Coastguard Worker   s.backward();
2284*da0073e9SAndroid Build Coastguard Worker 
2285*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected, 1e-04));
2286*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(input.sizes(), input.grad().sizes());
2287*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
2288*da0073e9SAndroid Build Coastguard Worker       CrossEntropyLoss(
2289*da0073e9SAndroid Build Coastguard Worker           CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
2290*da0073e9SAndroid Build Coastguard Worker           ->forward(input, target)
2291*da0073e9SAndroid Build Coastguard Worker           .allclose(expected, 1e-04));
2292*da0073e9SAndroid Build Coastguard Worker 
2293*da0073e9SAndroid Build Coastguard Worker   // label smoothing with class indices
2294*da0073e9SAndroid Build Coastguard Worker   loss = CrossEntropyLoss(
2295*da0073e9SAndroid Build Coastguard Worker       CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
2296*da0073e9SAndroid Build Coastguard Worker   input = torch::tensor(
2297*da0073e9SAndroid Build Coastguard Worker       {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2298*da0073e9SAndroid Build Coastguard Worker   target = torch::tensor({0, 1}, torch::kLong);
2299*da0073e9SAndroid Build Coastguard Worker   output = loss->forward(input, target);
2300*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(0.3326, torch::kFloat);
2301*da0073e9SAndroid Build Coastguard Worker   s = output.sum();
2302*da0073e9SAndroid Build Coastguard Worker   s.backward();
2303*da0073e9SAndroid Build Coastguard Worker 
2304*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected, 1e-04));
2305*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(input.sizes(), input.grad().sizes());
2306*da0073e9SAndroid Build Coastguard Worker 
2307*da0073e9SAndroid Build Coastguard Worker   // label smoothing with with target probabilities
2308*da0073e9SAndroid Build Coastguard Worker   loss = CrossEntropyLoss(
2309*da0073e9SAndroid Build Coastguard Worker       CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
2310*da0073e9SAndroid Build Coastguard Worker   input = torch::tensor(
2311*da0073e9SAndroid Build Coastguard Worker       {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2312*da0073e9SAndroid Build Coastguard Worker   target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
2313*da0073e9SAndroid Build Coastguard Worker   output = loss->forward(input, target);
2314*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(0.5701, torch::kFloat);
2315*da0073e9SAndroid Build Coastguard Worker   s = output.sum();
2316*da0073e9SAndroid Build Coastguard Worker   s.backward();
2317*da0073e9SAndroid Build Coastguard Worker 
2318*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected, 1e-04));
2319*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(input.sizes(), input.grad().sizes());
2320*da0073e9SAndroid Build Coastguard Worker }
2321*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, CosineSimilarity) {
  // Row-wise (dim=1) cosine similarity between two 2x3 inputs.
  CosineSimilarity cos(CosineSimilarityOptions().dim(1));
  auto input1 = torch::tensor(
      {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto input2 = torch::tensor(
      {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = cos->forward(input1, input2);
  auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected, 1e-04));
  // Both operands require grad and contribute to the similarity, so both
  // must receive a gradient of matching shape (previously only input1 was
  // checked).
  ASSERT_EQ(input1.sizes(), input1.grad().sizes());
  ASSERT_EQ(input2.sizes(), input2.grad().sizes());
}
2336*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SoftMarginLossDefaultOptions) {
  // Default-constructed module. The expected value is the mean of
  // log(1 + exp(-target * input)) over the four elements.
  SoftMarginLoss criterion;
  auto scores = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  const auto signs = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);

  const auto result = criterion->forward(scores, signs);
  const auto reference = torch::tensor({1.3767317}, torch::kFloat);
  const auto root = result.sum();
  root.backward();

  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
2350*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
  // Two samples with four labels each and multi-hot float targets, using
  // the default module options.
  MultiLabelSoftMarginLoss criterion;
  auto scores = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  const auto labels =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);

  const auto result = criterion->forward(scores, labels);
  result.sum().backward();

  const auto reference = torch::tensor({0.7608436}, torch::kFloat);
  ASSERT_TRUE(result.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
2366*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SoftMarginLossNoReduction) {
  // Unreduced: each entry equals log(1 + exp(-target * input)) for the
  // corresponding element.
  SoftMarginLoss criterion(torch::kNone);
  const auto grad_opts = torch::dtype(torch::kFloat).requires_grad(true);
  auto scores = torch::tensor({2., 4., 1., 3.}, grad_opts);
  const auto signs = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);

  const auto per_element = criterion->forward(scores, signs);
  const auto reference = torch::tensor(
      {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
  // Sum so that backward() has a scalar root.
  per_element.sum().backward();

  ASSERT_TRUE(per_element.allclose(reference));
  ASSERT_EQ(scores.sizes(), scores.grad().sizes());
}
2381*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
  auto logits = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto labels =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);

  // Per-class weights; without reduction one value is expected per row.
  auto class_weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
  MultiLabelSoftMarginLoss criterion(MultiLabelSoftMarginLossOptions()
                                         .reduction(torch::kNone)
                                         .weight(class_weight));

  auto per_row = criterion->forward(logits, labels);
  const auto expected_per_row =
      torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
  per_row.sum().backward();

  ASSERT_TRUE(per_row.allclose(expected_per_row));
  ASSERT_EQ(logits.sizes(), logits.grad().sizes());
}
2400*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PairwiseDistance) {
  // p = 1: distance between matching rows is the sum of absolute differences.
  PairwiseDistance dist(PairwiseDistanceOptions().p(1));
  auto input1 = torch::tensor(
      {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto input2 = torch::tensor(
      {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto output = dist->forward(input1, input2);
  // Row 0: 0 + 6 + 0 = 6; row 1: 2 + 4 + 0 = 6.
  auto expected = torch::tensor({6, 6}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  // Both operands were created with requires_grad(true), so backward must
  // deliver a correctly shaped gradient to each of them, not just the first.
  ASSERT_EQ(input1.sizes(), input1.grad().sizes());
  ASSERT_EQ(input2.sizes(), input2.grad().sizes());
}
2415*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ELU) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
    for (const auto inplace : {false, true}) {
      ELU model{ELUOptions().alpha(alpha).inplace(inplace)};
      auto x = torch::linspace(-10.0, 10.0, size * size * size);
      x.resize_({size, size, size});
      // Only the out-of-place variant is exercised through autograd.
      if (!inplace) {
        x.requires_grad_(true);
      }
      // Snapshot the input: the in-place variant overwrites `x`.
      auto x_orig = x.clone();
      auto y = model(x);
      torch::Tensor s = y.sum();

      ASSERT_EQ(s.ndimension(), 0);
      ASSERT_EQ(y.ndimension(), 3);
      ASSERT_EQ(y.sizes(), expected_shape);

      // Reference: max(0, x) + min(0, alpha * (exp(x) - 1)).
      const auto zeros = torch::zeros_like(x_orig);
      const auto y_exp = torch::max(zeros, x_orig) +
          torch::min(zeros, alpha * (torch::exp(x_orig) - 1.0));
      ASSERT_TRUE(torch::allclose(y, y_exp));
      if (!inplace) {
        s.backward();
      } else {
        ASSERT_TRUE(torch::allclose(x, y_exp));
      }
    }
  }
}
2446*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, SELU) {
  // SELU constants used to build the reference output.
  const double scale = 1.0507009873554804934193349852946;
  const double alpha = 1.6732632423543772848170429916717;
  for (const auto inplace : {false, true}) {
    SELU model(inplace);
    auto input = torch::randn({5, 5});
    if (!inplace) {
      input.requires_grad_(true);
    }
    // Keep the pre-forward values; the in-place variant mutates `input`.
    auto input_orig = input.clone();
    auto output = model->forward(input);

    // Reference: scale * (max(0, x) + min(0, alpha * (exp(x) - 1))).
    auto zero = torch::zeros_like(input);
    const auto positive_part = torch::max(zero, input_orig);
    const auto negative_part =
        torch::min(zero, alpha * (torch::exp(input_orig) - 1));
    const auto expected = scale * (positive_part + negative_part);
    auto s = output.sum();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_TRUE(output.allclose(expected));
    if (!inplace) {
      s.backward();
    } else {
      ASSERT_TRUE(input.allclose(expected));
    }
  }
}
2473*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Hardshrink) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    Hardshrink model{HardshrinkOptions().lambda(lambda)};
    auto x = torch::linspace(-10.0, 10.0, size * size * size);
    x.resize_({size, size, size}).set_requires_grad(true);
    auto y = model(x);
    torch::Tensor s = y.sum();
    s.backward();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), expected_shape);

    // Reference: elements with |x| > lambda pass through, the rest become 0.
    const auto keep_mask = x.abs() > lambda;
    ASSERT_TRUE(torch::allclose(y, keep_mask * x));
  }
}
2491*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Hardtanh) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
    for (const auto max_val : {0.42, 1.0, 4.2}) {
      for (const auto inplace : {false, true}) {
        Hardtanh model{HardtanhOptions()
                           .min_val(min_val)
                           .max_val(max_val)
                           .inplace(inplace)};
        auto x = torch::linspace(-10.0, 10.0, size * size * size);
        x.resize_({size, size, size});
        if (!inplace) {
          x.requires_grad_(true);
        }
        // Untouched copy of the input for the reference computation.
        auto x_orig = x.clone();
        auto y = model(x);
        torch::Tensor s = y.sum();

        ASSERT_EQ(s.ndimension(), 0);
        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), expected_shape);

        // Piecewise reference: min_val below the range, identity inside it,
        // max_val above it.
        const auto below = x_orig < min_val;
        const auto within = (x_orig >= min_val) * (x_orig <= max_val);
        const auto above = x_orig > max_val;
        const auto y_exp = below * min_val + within * x_orig + above * max_val;
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (!inplace) {
          s.backward();
        } else {
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
}
2525*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, HardtanhMinValGEMaxVal) {
  const char* const expected_msg = "max_val must be greater than min_val";

  // Construction rejects a max_val equal to, or below, min_val.
  for (const auto bad_max : {0.42, -0.42}) {
    ASSERT_THROWS_WITH(
        Hardtanh{HardtanhOptions().min_val(0.42).max_val(bad_max)},
        expected_msg);
  }

  // Options mutated after construction are re-checked by reset().
  Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)};
  ht->options.min_val(0.42);
  ASSERT_THROWS_WITH(ht->reset(), expected_msg);
  ht->options.max_val(-0.42);
  ASSERT_THROWS_WITH(ht->reset(), expected_msg);
}
2540*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, LeakyReLU) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto inplace : {false, true}) {
    for (const auto negative_slope : {0.0, 0.42, 1.0}) {
      for (const auto type : {torch::kFloat, torch::kBFloat16}) {
        LeakyReLU model{
            LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
        auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
        x.resize_({size, size, size});
        if (!inplace) {
          x.requires_grad_(true);
        }
        // Copy taken before forward, since the in-place variant mutates `x`.
        auto x_orig = x.clone();
        auto y = model(x);
        torch::Tensor s = y.sum();

        ASSERT_EQ(s.ndimension(), 0);
        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), expected_shape);

        // Reference: negatives scaled by negative_slope, non-negatives kept.
        // Operation order is kept as-is so low-precision rounding matches.
        const auto negative_part = (x_orig < 0) * x_orig * negative_slope;
        const auto non_negative_part = (x_orig >= 0) * x_orig;
        const auto y_exp = negative_part + non_negative_part;
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (!inplace) {
          s.backward();
        } else {
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
}
2572*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, LogSigmoid) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  LogSigmoid model;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size}).set_requires_grad(true);
  auto y = model(x);
  torch::Tensor s = y.sum();
  s.backward();

  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), expected_shape);

  // Reference: log(1 / (1 + exp(-x))), compared with a loosened tolerance.
  const auto ones = torch::ones_like(x);
  const auto y_exp = torch::log(ones / (ones + torch::exp(torch::neg(x))));
  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
}
2590*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softmax) {
  Softmax m(/*dim=*/1);
  auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  auto output = m(input);
  // Row-wise normaliser: sum of exp(x) along dim 1.
  const auto row_totals = torch::exp(input).sum(1);

  for (const auto row : c10::irange(2)) {
    ASSERT_TRUE(torch::allclose(
        output[row], torch::exp(input[row]) / row_totals[row]));
  }
}
2602*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softmin) {
  Softmin m(/*dim=*/1);
  auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  auto output = m(input);
  // Softmin is softmax of the negated input: normalise exp(-x) along dim 1.
  const auto row_totals = torch::exp(-input).sum(1);

  for (const auto row : c10::irange(2)) {
    ASSERT_TRUE(torch::allclose(
        output[row], torch::exp(-input[row]) / row_totals[row]));
  }
}
2614*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, LogSoftmax) {
  LogSoftmax m(/*dim=*/1);
  auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  auto output = m(input);
  // Reference is the log of the row-wise softmax along dim 1.
  const auto row_totals = torch::exp(input).sum(1);

  for (const auto row : c10::irange(2)) {
    const auto reference = torch::log(torch::exp(input[row]) / row_totals[row]);
    ASSERT_TRUE(torch::allclose(output[row], reference));
  }
}
2626*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,AdaptiveLogSoftmaxWithLoss)2627*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
2628*da0073e9SAndroid Build Coastguard Worker   {
2629*da0073e9SAndroid Build Coastguard Worker     // log_probs actually returns log_proba
2630*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
2631*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2632*da0073e9SAndroid Build Coastguard Worker     auto x = torch::randn({4, 8});
2633*da0073e9SAndroid Build Coastguard Worker     auto logprob_out = asfm->log_prob(x);
2634*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
2635*da0073e9SAndroid Build Coastguard Worker         torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
2636*da0073e9SAndroid Build Coastguard Worker   }
2637*da0073e9SAndroid Build Coastguard Worker   {
2638*da0073e9SAndroid Build Coastguard Worker     // test predict
2639*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
2640*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
2641*da0073e9SAndroid Build Coastguard Worker             .div_value(2.)
2642*da0073e9SAndroid Build Coastguard Worker             .head_bias(true));
2643*da0073e9SAndroid Build Coastguard Worker     auto x = torch::randn({64, 8});
2644*da0073e9SAndroid Build Coastguard Worker     auto logprob_out = asfm->log_prob(x);
2645*da0073e9SAndroid Build Coastguard Worker     auto predict_out = asfm->predict(x);
2646*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
2647*da0073e9SAndroid Build Coastguard Worker   }
2648*da0073e9SAndroid Build Coastguard Worker   {
2649*da0073e9SAndroid Build Coastguard Worker     // cluster sizes
2650*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
2651*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2652*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
2653*da0073e9SAndroid Build Coastguard Worker     auto y = torch::tensor({0, 17}, torch::kLong);
2654*da0073e9SAndroid Build Coastguard Worker     auto asm_out = asfm(x, y);
2655*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
2656*da0073e9SAndroid Build Coastguard Worker   }
2657*da0073e9SAndroid Build Coastguard Worker   {
2658*da0073e9SAndroid Build Coastguard Worker     // forward returns the same thing as log_probs
2659*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
2660*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2661*da0073e9SAndroid Build Coastguard Worker     auto x = torch::randn({4, 8});
2662*da0073e9SAndroid Build Coastguard Worker     auto logprob_out = asfm->log_prob(x);
2663*da0073e9SAndroid Build Coastguard Worker     NLLLoss nll_loss;
2664*da0073e9SAndroid Build Coastguard Worker 
2665*da0073e9SAndroid Build Coastguard Worker     for (const auto v : c10::irange(4)) {
2666*da0073e9SAndroid Build Coastguard Worker       auto y = torch::full({4}, v, torch::kLong);
2667*da0073e9SAndroid Build Coastguard Worker       auto asm_out = asfm(x, y);
2668*da0073e9SAndroid Build Coastguard Worker       auto out = asm_out.output;
2669*da0073e9SAndroid Build Coastguard Worker       auto loss = torch::tensor(asm_out.loss, torch::kFloat);
2670*da0073e9SAndroid Build Coastguard Worker       auto expected = nll_loss->forward(logprob_out, y);
2671*da0073e9SAndroid Build Coastguard Worker 
2672*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(torch::allclose(loss, expected));
2673*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(torch::allclose(
2674*da0073e9SAndroid Build Coastguard Worker           out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
2675*da0073e9SAndroid Build Coastguard Worker     }
2676*da0073e9SAndroid Build Coastguard Worker   }
2677*da0073e9SAndroid Build Coastguard Worker   {
2678*da0073e9SAndroid Build Coastguard Worker     // test no batch dim
2679*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
2680*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2681*da0073e9SAndroid Build Coastguard Worker     auto x = torch::randn({1, 16});
2682*da0073e9SAndroid Build Coastguard Worker     auto y = torch::tensor({17});
2683*da0073e9SAndroid Build Coastguard Worker     auto x2 = x.squeeze(0);
2684*da0073e9SAndroid Build Coastguard Worker     auto y2 = y.squeeze(0);
2685*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
2686*da0073e9SAndroid Build Coastguard Worker         torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
2687*da0073e9SAndroid Build Coastguard Worker   }
2688*da0073e9SAndroid Build Coastguard Worker   {
2689*da0073e9SAndroid Build Coastguard Worker     // test div_value
2690*da0073e9SAndroid Build Coastguard Worker     auto options =
2691*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
2692*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
2693*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLoss(options),
2694*da0073e9SAndroid Build Coastguard Worker         "div_value should not be equal to 0");
2695*da0073e9SAndroid Build Coastguard Worker 
2696*da0073e9SAndroid Build Coastguard Worker     options =
2697*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
2698*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
2699*da0073e9SAndroid Build Coastguard Worker   }
2700*da0073e9SAndroid Build Coastguard Worker }
2701*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softmax2d) {
  Softmax2d m;
  auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
  auto output = m(input);
  // The reference normalises exp(x) over dim 1 at every (n, h, w) position.
  const auto channel_totals = torch::exp(input).sum(1);

  for (const auto n : c10::irange(input.size(0))) {
    for (const auto c : c10::irange(input.size(1))) {
      for (const auto h : c10::irange(input.size(2))) {
        for (const auto w : c10::irange(input.size(3))) {
          const auto expected =
              torch::exp(input[n][c][h][w]) / channel_totals[n][h][w];
          ASSERT_TRUE(torch::allclose(output[n][c][h][w], expected));
        }
      }
    }
  }
}
2719*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PReLU) {
  const auto num_parameters = 42;
  const auto init = 0.42;

  PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)};

  // One slope per parameter, every slope initialised to `init`.
  ASSERT_EQ(model->weight.sizes(), std::vector<int64_t>({num_parameters}));
  ASSERT_TRUE(
      torch::allclose(model->weight, torch::full({num_parameters}, init)));

  // rand() rescaled to roughly [-100, 100) so both branches are hit.
  const auto x = torch::rand({100, num_parameters}) * 200 - 100;
  const auto y = model(x);
  const auto s = y.sum();
  s.backward();

  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.ndimension(), x.ndimension());
  ASSERT_EQ(y.sizes(), x.sizes());

  // Reference: negatives scaled by the (broadcast) weight, others unchanged.
  const auto negative = x < 0;
  const auto y_exp = negative * model->weight * x + (x >= 0) * x;
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
2742*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ReLU) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto inplace : {false, true}) {
    ReLU model(inplace);
    auto x = torch::linspace(-10.0, 10.0, size * size * size);
    x.resize_({size, size, size});
    if (!inplace) {
      x.requires_grad_(true);
    }
    // Snapshot before forward; the in-place variant overwrites `x`.
    auto x_orig = x.clone();
    auto y = model(x);
    torch::Tensor s = y.sum();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), expected_shape);

    // Reference: zero for negatives, identity otherwise.
    const auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig;
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (!inplace) {
      s.backward();
    } else {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }
  }
}
2768*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ReLU6) {
  const auto size = 3;
  const std::vector<int64_t> expected_shape{size, size, size};
  for (const auto inplace : {false, true}) {
    ReLU6 model(inplace);
    auto x = torch::linspace(-10.0, 10.0, size * size * size);
    x.resize_({size, size, size});
    if (!inplace) {
      x.requires_grad_(true);
    }
    // Snapshot before forward; the in-place variant overwrites `x`.
    auto x_orig = x.clone();
    auto y = model(x);
    torch::Tensor s = y.sum();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), expected_shape);

    // Reference: 0 below zero, identity on [0, 6], 6 above six.
    const auto low_part = (x_orig < 0) * 0;
    const auto linear_part = ((x_orig >= 0) * (x_orig <= 6)) * x_orig;
    const auto high_part = (x_orig > 6) * 6;
    const auto y_exp = low_part + linear_part + high_part;
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (!inplace) {
      s.backward();
    } else {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }
  }
}
2795*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, RReLU) {
  // The module is used in its default state here. Each output element is
  // checked against the bounds below: non-negative inputs must pass through
  // unchanged, and negative inputs must land between upper * x and lower * x.
  const auto size = 3;
  for (const auto lower : {0.01, 0.1, 0.2}) {
    for (const auto upper : {0.3, 0.4, 0.5}) {
      for (const auto inplace : {false, true}) {
        for (const auto type : {torch::kFloat, torch::kBFloat16}) {
          RReLU model{
              RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
          auto input =
              torch::linspace(-10.0, 10.0, size * size * size).to(type);
          input.resize_({size, size, size});
          if (!inplace) {
            input.requires_grad_(true);
          }
          // Untouched copy; the in-place variant overwrites `input`.
          auto reference = input.clone();
          auto output = model(input);
          torch::Tensor total = output.sum();

          ASSERT_EQ(total.ndimension(), 0);
          ASSERT_EQ(output.ndimension(), 3);
          ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));

          // Element-wise flags, converted to 1.0 where the contract holds.
          const auto passthrough_ok = (reference >= 0) * (reference == output);
          const auto slope_ok = (reference < 0) *
              (output >= reference * upper) * (output <= lower * reference);
          const auto valid = (passthrough_ok + slope_ok) * 1.0;
          ASSERT_TRUE(torch::allclose(valid, torch::ones_like(valid)));

          if (!inplace) {
            total.backward();
          } else {
            ASSERT_TRUE(torch::allclose(input, output));
          }
        }
      }
    }
  }
}
2831*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, CELU) {
  // Expected: max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), checked for
  // several alpha values both in place and out of place.
  const auto size = 3;
  for (const auto inplace : {false, true}) {
    for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
      CELU model{CELUOptions().alpha(alpha).inplace(inplace)};
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_({size, size, size});
      if (!inplace) {
        input.requires_grad_(true);
      }
      // Pristine copy; the in-place variant overwrites `input`.
      auto reference = input.clone();
      auto output = model(input);
      torch::Tensor total = output.sum();

      ASSERT_EQ(total.ndimension(), 0);
      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));

      const auto zeros = torch::zeros_like(reference);
      const auto positive_part = torch::max(zeros, reference);
      const auto negative_part =
          torch::min(zeros, alpha * (torch::exp(reference / alpha) - 1.0));
      const auto expected = positive_part + negative_part;
      ASSERT_TRUE(torch::allclose(output, expected));

      if (!inplace) {
        total.backward();
      } else {
        ASSERT_TRUE(torch::allclose(input, expected));
      }
    }
  }
}
2861*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, GLU) {
  // Expected: the first half of the input along `dim`, gated by the sigmoid
  // of the second half.
  const int64_t dim = 1;
  GLU model(dim);
  auto input = torch::randn({4, 2}, torch::requires_grad());
  auto output = model->forward(input);

  const auto half = input.sizes()[dim] / 2;
  const auto gated = input.narrow(dim, 0, half);
  const auto gate = input.narrow(dim, half, half);
  const auto expected = gated * torch::sigmoid(gate);

  auto total = output.sum();
  total.backward();

  ASSERT_EQ(total.ndimension(), 0);
  ASSERT_TRUE(output.allclose(expected));

  // A default-constructed module must produce the same result for this input.
  GLU default_model;
  ASSERT_TRUE(default_model->forward(input).allclose(expected));
}
2880*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, GELU) {
  // Exact (non-approximated) GELU: x * 0.5 * (1 + erf(x / sqrt(2))).
  GELU model(GELUOptions().approximate("none"));
  const auto x = torch::linspace(-3.0, 3.0, 100);
  const auto y = model(x);
  const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
  ASSERT_TRUE(torch::allclose(y, y_exp, /*rtol=*/1.4e-06, /*atol=*/1e-05));
}
2888*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, TanhGELU) {
  // tanh approximation of GELU:
  //   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
  GELU model(GELUOptions().approximate("tanh"));
  const auto x = torch::linspace(-3.0, 3.0, 100);
  const auto scale = std::sqrt(2 / M_PI);
  const auto cubic = x + 0.044715 * x.pow(3.0);
  const auto y_exp = 0.5 * x * (1.0 + torch::tanh(scale * cubic));
  const auto y = model(x);
  ASSERT_TRUE(torch::allclose(y, y_exp, /*rtol=*/1.4e-06, /*atol=*/1e-05));
}
2897*da0073e9SAndroid Build Coastguard Worker 
2898*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModulesTest,Mish)2899*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Mish) {
2900*da0073e9SAndroid Build Coastguard Worker   Mish model;
2901*da0073e9SAndroid Build Coastguard Worker   auto x = torch::randn(100) * 10;
2902*da0073e9SAndroid Build Coastguard Worker   auto y_exp = x * x.exp().log1p().tanh();
2903*da0073e9SAndroid Build Coastguard Worker   auto y = model(x);
2904*da0073e9SAndroid Build Coastguard Worker 
2905*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(y, y_exp));
2906*da0073e9SAndroid Build Coastguard Worker }
2907*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Sigmoid) {
  // Logistic function 1 / (1 + exp(-x)) on random inputs.
  Sigmoid model;
  auto x = torch::randn(100) * 10;
  const auto denominator = 1 + torch::exp(-x);
  auto y_exp = 1 / denominator;
  auto y = model(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}
2916*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PixelShuffle) {
  // A 1x4x2x2 input shuffled with factor 2 becomes a 1x1x4x4 output. Each
  // 2x2 output tile interleaves the four input channels at one spatial
  // position.
  PixelShuffle shuffle(/*upscale_factor=*/2);
  auto input = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto expected = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto output = shuffle(input);

  ASSERT_EQ(output.ndimension(), 4);
  ASSERT_EQ(output.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
  ASSERT_TRUE(output.allclose(expected));
}
2934*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PixelUnshuffle) {
  // Inverse of the PixelShuffle case: a 1x1x4x4 input unshuffled with
  // factor 2 must yield the original 1x4x2x2 channel layout.
  PixelUnshuffle unshuffle(/*downscale_factor=*/2);
  auto input = torch::tensor(
      {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
      torch::kFloat);
  auto expected = torch::tensor(
      {{{{-17, 19}, {-1, 2}},
        {{7, 14}, {-3, 1}},
        {{0, -2}, {-12, 14}},
        {{-15, 0}, {-3, 9}}}},
      torch::kFloat);
  auto output = unshuffle(input);

  ASSERT_EQ(output.ndimension(), 4);
  ASSERT_EQ(output.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
  ASSERT_TRUE(output.allclose(expected));
}
2952*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softplus) {
  // Softplus(x) = log(1 + exp(beta * x)) / beta, except that it reverts to
  // the identity wherever beta * x > threshold (numerical-stability cutoff).
  const auto size = 3;
  for (const auto beta : {0.5, 1.0, 2.0}) {
    for (const auto threshold : {1.0, 3.0, 5.0}) {
      Softplus model{SoftplusOptions().beta(beta).threshold(threshold)};
      // Generate exactly size^3 points. resize_ keeps only the leading
      // elements, so any larger count would drop the upper part of [-3, 3]
      // and leave the linear branch untested.
      auto x = torch::linspace(-3.0, 3.0, size * size * size);
      x.resize_({size, size, size});
      // The cutoff applies to the scaled input, not to x itself.
      const auto scaled = x * beta;
      auto y_exp =
          (scaled <= threshold) * torch::log1p(torch::exp(scaled)) / beta +
          (scaled > threshold) * x;
      auto y = model(x);

      ASSERT_EQ(y.ndimension(), 3);
      ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(y, y_exp));
    }
  }
}
2971*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softshrink) {
  // Expected: values above lambda shifted down by lambda, values below
  // -lambda shifted up by lambda, everything in between zero.
  const auto size = 3;
  for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
    Softshrink model{/*lambda=*/lambda};
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size}).set_requires_grad(true);
    auto output = model(input);
    torch::Tensor total = output.sum();

    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({size, size, size}));
    const auto lower_tail = (input < -lambda) * (input + lambda);
    const auto upper_tail = (input > lambda) * (input - lambda);
    ASSERT_TRUE(torch::allclose(output, lower_tail + upper_tail));
  }
}
2990*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Softsign) {
  // Softsign(x) = x / (1 + |x|) on random inputs.
  Softsign model;
  auto x = torch::randn(100) * 10;
  const auto denominator = 1 + x.abs();
  auto y_exp = x / denominator;
  auto y = model(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}
2999*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Tanh) {
  // Compare against the definition (e^x - e^-x) / (e^x + e^-x).
  Tanh model;
  auto x = torch::randn(100) * 10;
  const auto exp_pos = x.exp();
  const auto exp_neg = (-x).exp();
  auto y_exp = (exp_pos - exp_neg) / (exp_pos + exp_neg);
  auto y = model(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}
3008*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Tanhshrink) {
  // Tanhshrink(x) = x - tanh(x) on random inputs.
  Tanhshrink model;
  auto x = torch::randn(100) * 10;
  auto y_exp = x - torch::tanh(x);
  auto y = model(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}
3017*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, Threshold) {
  // Threshold replaces every element <= threshold with `value` and passes
  // larger elements through unchanged; checked in place and out of place.
  const auto size = 3;
  for (const auto threshold : {0.5, 1.0, 2.0}) {
    for (const auto value : {0.5, 1.0, 2.0}) {
      for (const auto inplace : {false, true}) {
        Threshold model{ThresholdOptions(threshold, value).inplace(inplace)};
        // Generate exactly size^3 points. resize_ keeps only the leading
        // elements, so any larger count would drop the upper part of [-3, 3].
        auto x = torch::linspace(-3.0, 3.0, size * size * size);
        x.resize_({size, size, size});
        // Untouched copy; the in-place variant overwrites `x`.
        auto x_orig = x.clone();
        // Guard against the input range silently missing the pass-through
        // branch.
        ASSERT_TRUE((x_orig > threshold).any().item<bool>());
        auto y_exp =
            (x_orig <= threshold) * value + (x_orig > threshold) * x_orig;
        auto y = model(x);

        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
}
3041*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,Upsampling1D)3042*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Upsampling1D) {
3043*da0073e9SAndroid Build Coastguard Worker   {
3044*da0073e9SAndroid Build Coastguard Worker     Upsample model(UpsampleOptions()
3045*da0073e9SAndroid Build Coastguard Worker                        .size(std::vector<int64_t>({4}))
3046*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kNearest));
3047*da0073e9SAndroid Build Coastguard Worker     auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3048*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
3049*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::ones({1, 1, 4});
3050*da0073e9SAndroid Build Coastguard Worker     auto s = output.sum();
3051*da0073e9SAndroid Build Coastguard Worker     s.backward();
3052*da0073e9SAndroid Build Coastguard Worker 
3053*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(s.ndimension(), 0);
3054*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
3055*da0073e9SAndroid Build Coastguard Worker   }
3056*da0073e9SAndroid Build Coastguard Worker   {
3057*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
3058*da0073e9SAndroid Build Coastguard Worker       // test float scale factor up & down sampling
3059*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3060*da0073e9SAndroid Build Coastguard Worker         Upsample model(UpsampleOptions()
3061*da0073e9SAndroid Build Coastguard Worker                            .scale_factor(std::vector<double>({scale_factor}))
3062*da0073e9SAndroid Build Coastguard Worker                            .mode(torch::kLinear)
3063*da0073e9SAndroid Build Coastguard Worker                            .align_corners(align_corners));
3064*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3065*da0073e9SAndroid Build Coastguard Worker         auto output = model->forward(input);
3066*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
3067*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3068*da0073e9SAndroid Build Coastguard Worker         auto expected = torch::ones({1, 1, expected_size});
3069*da0073e9SAndroid Build Coastguard Worker         auto s = output.sum();
3070*da0073e9SAndroid Build Coastguard Worker         s.backward();
3071*da0073e9SAndroid Build Coastguard Worker 
3072*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(s.ndimension(), 0);
3073*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
3074*da0073e9SAndroid Build Coastguard Worker       }
3075*da0073e9SAndroid Build Coastguard Worker     }
3076*da0073e9SAndroid Build Coastguard Worker   }
3077*da0073e9SAndroid Build Coastguard Worker   {
3078*da0073e9SAndroid Build Coastguard Worker     // linear (1D) upsampling spatial invariance
3079*da0073e9SAndroid Build Coastguard Worker     Upsample model(UpsampleOptions()
3080*da0073e9SAndroid Build Coastguard Worker                        .scale_factor(std::vector<double>({3}))
3081*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kLinear)
3082*da0073e9SAndroid Build Coastguard Worker                        .align_corners(false));
3083*da0073e9SAndroid Build Coastguard Worker     auto input = torch::zeros({1, 1, 9});
3084*da0073e9SAndroid Build Coastguard Worker     input.narrow(2, 0, 4).normal_();
3085*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
3086*da0073e9SAndroid Build Coastguard Worker     auto expected = model->forward(input.narrow(2, 0, 5));
3087*da0073e9SAndroid Build Coastguard Worker 
3088*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(output.narrow(2, 0, 15), expected));
3089*da0073e9SAndroid Build Coastguard Worker   }
3090*da0073e9SAndroid Build Coastguard Worker }
3091*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,Upsampling2D)3092*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Upsampling2D) {
3093*da0073e9SAndroid Build Coastguard Worker   {
3094*da0073e9SAndroid Build Coastguard Worker     Upsample model(UpsampleOptions()
3095*da0073e9SAndroid Build Coastguard Worker                        .size(std::vector<int64_t>({4, 4}))
3096*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kNearest));
3097*da0073e9SAndroid Build Coastguard Worker     auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3098*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
3099*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::ones({1, 1, 4, 4});
3100*da0073e9SAndroid Build Coastguard Worker     auto s = output.sum();
3101*da0073e9SAndroid Build Coastguard Worker     s.backward();
3102*da0073e9SAndroid Build Coastguard Worker 
3103*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(s.ndimension(), 0);
3104*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
3105*da0073e9SAndroid Build Coastguard Worker   }
3106*da0073e9SAndroid Build Coastguard Worker   {
3107*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
3108*da0073e9SAndroid Build Coastguard Worker       // test float scale factor up & down sampling
3109*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3110*da0073e9SAndroid Build Coastguard Worker         Upsample model(
3111*da0073e9SAndroid Build Coastguard Worker             UpsampleOptions()
3112*da0073e9SAndroid Build Coastguard Worker                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3113*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBilinear)
3114*da0073e9SAndroid Build Coastguard Worker                 .align_corners(align_corners));
3115*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3116*da0073e9SAndroid Build Coastguard Worker         auto output = model->forward(input);
3117*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
3118*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3119*da0073e9SAndroid Build Coastguard Worker         auto expected = torch::ones({1, 1, expected_size, expected_size});
3120*da0073e9SAndroid Build Coastguard Worker         auto s = output.sum();
3121*da0073e9SAndroid Build Coastguard Worker         s.backward();
3122*da0073e9SAndroid Build Coastguard Worker 
3123*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(s.ndimension(), 0);
3124*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
3125*da0073e9SAndroid Build Coastguard Worker       }
3126*da0073e9SAndroid Build Coastguard Worker     }
3127*da0073e9SAndroid Build Coastguard Worker   }
3128*da0073e9SAndroid Build Coastguard Worker   {
3129*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
3130*da0073e9SAndroid Build Coastguard Worker       // test float scale factor up & down sampling
3131*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3132*da0073e9SAndroid Build Coastguard Worker         Upsample model(
3133*da0073e9SAndroid Build Coastguard Worker             UpsampleOptions()
3134*da0073e9SAndroid Build Coastguard Worker                 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3135*da0073e9SAndroid Build Coastguard Worker                 .mode(torch::kBicubic)
3136*da0073e9SAndroid Build Coastguard Worker                 .align_corners(align_corners));
3137*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3138*da0073e9SAndroid Build Coastguard Worker         auto output = model->forward(input);
3139*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
3140*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3141*da0073e9SAndroid Build Coastguard Worker         auto expected = torch::ones({1, 1, expected_size, expected_size});
3142*da0073e9SAndroid Build Coastguard Worker         auto s = output.sum();
3143*da0073e9SAndroid Build Coastguard Worker         s.backward();
3144*da0073e9SAndroid Build Coastguard Worker 
3145*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(s.ndimension(), 0);
3146*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
3147*da0073e9SAndroid Build Coastguard Worker       }
3148*da0073e9SAndroid Build Coastguard Worker     }
3149*da0073e9SAndroid Build Coastguard Worker   }
3150*da0073e9SAndroid Build Coastguard Worker }
3151*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,Upsampling3D)3152*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, Upsampling3D) {
3153*da0073e9SAndroid Build Coastguard Worker   {
3154*da0073e9SAndroid Build Coastguard Worker     Upsample model(UpsampleOptions()
3155*da0073e9SAndroid Build Coastguard Worker                        .size(std::vector<int64_t>({4, 4, 4}))
3156*da0073e9SAndroid Build Coastguard Worker                        .mode(torch::kNearest));
3157*da0073e9SAndroid Build Coastguard Worker     auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3158*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
3159*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::ones({1, 1, 4, 4, 4});
3160*da0073e9SAndroid Build Coastguard Worker     auto s = output.sum();
3161*da0073e9SAndroid Build Coastguard Worker     s.backward();
3162*da0073e9SAndroid Build Coastguard Worker 
3163*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(s.ndimension(), 0);
3164*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
3165*da0073e9SAndroid Build Coastguard Worker   }
3166*da0073e9SAndroid Build Coastguard Worker   {
3167*da0073e9SAndroid Build Coastguard Worker     for (const auto align_corners : {true, false}) {
3168*da0073e9SAndroid Build Coastguard Worker       // test float scale factor up & down sampling
3169*da0073e9SAndroid Build Coastguard Worker       for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3170*da0073e9SAndroid Build Coastguard Worker         Upsample model(UpsampleOptions()
3171*da0073e9SAndroid Build Coastguard Worker                            .scale_factor(std::vector<double>(
3172*da0073e9SAndroid Build Coastguard Worker                                {scale_factor, scale_factor, scale_factor}))
3173*da0073e9SAndroid Build Coastguard Worker                            .mode(torch::kTrilinear)
3174*da0073e9SAndroid Build Coastguard Worker                            .align_corners(align_corners));
3175*da0073e9SAndroid Build Coastguard Worker         auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3176*da0073e9SAndroid Build Coastguard Worker         auto output = model->forward(input);
3177*da0073e9SAndroid Build Coastguard Worker         auto expected_size =
3178*da0073e9SAndroid Build Coastguard Worker             static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3179*da0073e9SAndroid Build Coastguard Worker         auto expected =
3180*da0073e9SAndroid Build Coastguard Worker             torch::ones({1, 1, expected_size, expected_size, expected_size});
3181*da0073e9SAndroid Build Coastguard Worker         auto s = output.sum();
3182*da0073e9SAndroid Build Coastguard Worker         s.backward();
3183*da0073e9SAndroid Build Coastguard Worker 
3184*da0073e9SAndroid Build Coastguard Worker         ASSERT_EQ(s.ndimension(), 0);
3185*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(output.allclose(expected));
3186*da0073e9SAndroid Build Coastguard Worker       }
3187*da0073e9SAndroid Build Coastguard Worker     }
3188*da0073e9SAndroid Build Coastguard Worker   }
3189*da0073e9SAndroid Build Coastguard Worker }
3190*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, CTCLoss) {
  // Per-sample CTC loss (no reduction) for three sequences whose targets are
  // all empty.
  CTCLoss ctc_loss{CTCLossOptions().reduction(torch::kNone)};

  const auto target_lengths = torch::tensor({0, 0, 0});
  const auto input_lengths = torch::tensor({50, 50, 50});
  // Zero-length target buffer: no labels are referenced by any sample.
  const auto targets =
      torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
  // 50 time steps x 3 samples x 15 classes, log-softmax over the class dim.
  const auto log_probs =
      torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);

  const auto per_sample_loss =
      ctc_loss->forward(log_probs, targets, input_lengths, target_lengths);

  // A negative log-likelihood can never be negative.
  ASSERT_TRUE(per_sample_loss.ge(0).all().item<bool>());

  // With empty targets the only alignment is "blank at every step" (class 0),
  // so each sample's loss is minus the summed class-0 log-probabilities.
  const auto blank_only_loss =
      -log_probs.sum(0).slice(1, 0, 1).view_as(per_sample_loss);
  ASSERT_TRUE(torch::allclose(blank_only_loss, per_sample_loss));
}
3205*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PoissonNLLLoss) {
  const auto input = torch::tensor({0.5, 1.5, 2.5});
  const auto target = torch::tensor({1., 2., 3.});
  // Expected element-wise loss: exp(input) - target * input, i.e. `input`
  // is treated as a log-rate.
  const auto elementwise = torch::exp(input) - target * input;

  {
    // reduction = none: one loss value per element.
    PoissonNLLLoss criterion{PoissonNLLLossOptions().reduction(torch::kNone)};
    const auto actual = criterion->forward(input, target);
    ASSERT_TRUE(torch::allclose(elementwise, actual));
  }
  {
    // reduction = sum: total of the element-wise losses.
    PoissonNLLLoss criterion{PoissonNLLLossOptions().reduction(torch::kSum)};
    const auto expected = torch::sum(elementwise);
    ASSERT_TRUE(torch::allclose(expected, criterion->forward(input, target)));
  }
  {
    // reduction = mean: average of the element-wise losses.
    PoissonNLLLoss criterion{PoissonNLLLossOptions().reduction(torch::kMean)};
    const auto expected = torch::mean(elementwise);
    ASSERT_TRUE(torch::allclose(expected, criterion->forward(input, target)));
  }
}
3226*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, MarginRankingLoss) {
  // Reference formula checked below:
  //   loss_i = max(0, -y_i * (x1_i - x2_i) + margin)
  // followed by the configured reduction.
  {
    // Default options: expected to behave as margin 0 with mean reduction.
    MarginRankingLoss ranking_loss;
    const auto x1 = torch::randn(15) * 10;
    const auto x2 = torch::randn(15) * 10;
    const auto y = torch::randn(15).sign();
    const auto expected = (-y * (x1 - x2)).clamp(0).mean();
    ASSERT_TRUE(torch::allclose(ranking_loss->forward(x1, x2, y), expected));
  }
  {
    // Margin 0.5, summed.
    MarginRankingLoss ranking_loss{
        MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)};
    const auto x1 = torch::randn(15) * 10;
    const auto x2 = torch::randn(15) * 10;
    const auto y = torch::randn(15).sign();
    const auto margin = 0.5;
    const auto expected = (-y * (x1 - x2) + margin).clamp(0).sum();
    ASSERT_TRUE(torch::allclose(ranking_loss->forward(x1, x2, y), expected));
  }
  {
    // Margin 0.5, averaged.
    MarginRankingLoss ranking_loss{
        MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)};
    const auto x1 = torch::randn(15) * 10;
    const auto x2 = torch::randn(15) * 10;
    const auto y = torch::randn(15).sign();
    const auto margin = 0.5;
    const auto expected = (-y * (x1 - x2) + margin).clamp(0).mean();
    ASSERT_TRUE(torch::allclose(ranking_loss->forward(x1, x2, y), expected));
  }
}
3260*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,BCEWithLogitsLoss)3261*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, BCEWithLogitsLoss) {
3262*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits raises if target and input are different size
3263*da0073e9SAndroid Build Coastguard Worker     {
3264*da0073e9SAndroid Build Coastguard Worker       const auto target = torch::rand(5);
3265*da0073e9SAndroid Build Coastguard Worker       const auto input = torch::rand({5, 1});
3266*da0073e9SAndroid Build Coastguard Worker       ASSERT_THROWS_WITH(
3267*da0073e9SAndroid Build Coastguard Worker           BCEWithLogitsLoss()(input, target), "must be the same as input size");
3268*da0073e9SAndroid Build Coastguard Worker     }
3269*da0073e9SAndroid Build Coastguard Worker     {
3270*da0073e9SAndroid Build Coastguard Worker       const auto target = torch::rand({5, 1});
3271*da0073e9SAndroid Build Coastguard Worker       const auto input = torch::rand(5);
3272*da0073e9SAndroid Build Coastguard Worker       ASSERT_THROWS_WITH(
3273*da0073e9SAndroid Build Coastguard Worker           BCEWithLogitsLoss()(input, target), "must be the same as input size");
3274*da0073e9SAndroid Build Coastguard Worker     }
3275*da0073e9SAndroid Build Coastguard Worker   }
3276*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits gives same result as sigmoid and bce loss
3277*da0073e9SAndroid Build Coastguard Worker     auto sigmoid = Sigmoid();
3278*da0073e9SAndroid Build Coastguard Worker 
3279*da0073e9SAndroid Build Coastguard Worker     auto target = torch::rand({64, 4});
3280*da0073e9SAndroid Build Coastguard Worker     auto output = torch::rand({64, 4}) - 0.5;
3281*da0073e9SAndroid Build Coastguard Worker 
3282*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3283*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss()(output, target),
3284*da0073e9SAndroid Build Coastguard Worker         BCELoss()(sigmoid(output), target)));
3285*da0073e9SAndroid Build Coastguard Worker 
3286*da0073e9SAndroid Build Coastguard Worker     auto weight = torch::rand(4);
3287*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3288*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3289*da0073e9SAndroid Build Coastguard Worker             output, target),
3290*da0073e9SAndroid Build Coastguard Worker         BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3291*da0073e9SAndroid Build Coastguard Worker 
3292*da0073e9SAndroid Build Coastguard Worker     target = torch::zeros({4, 1}, torch::kFloat);
3293*da0073e9SAndroid Build Coastguard Worker     output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3294*da0073e9SAndroid Build Coastguard Worker 
3295*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3296*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss()(output, target),
3297*da0073e9SAndroid Build Coastguard Worker         BCELoss()(sigmoid(output), target)));
3298*da0073e9SAndroid Build Coastguard Worker 
3299*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3300*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))(
3301*da0073e9SAndroid Build Coastguard Worker             output, target),
3302*da0073e9SAndroid Build Coastguard Worker         BCELoss(BCELossOptions().reduction(torch::kNone))(
3303*da0073e9SAndroid Build Coastguard Worker             sigmoid(output), target)));
3304*da0073e9SAndroid Build Coastguard Worker 
3305*da0073e9SAndroid Build Coastguard Worker     weight = torch::rand({1}, torch::kFloat);
3306*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3307*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3308*da0073e9SAndroid Build Coastguard Worker             output, target),
3309*da0073e9SAndroid Build Coastguard Worker         BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3310*da0073e9SAndroid Build Coastguard Worker   }
3311*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits has correct grad at zero
3312*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3313*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::zeros({3, 1});
3314*da0073e9SAndroid Build Coastguard Worker     BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))(
3315*da0073e9SAndroid Build Coastguard Worker         output, target)
3316*da0073e9SAndroid Build Coastguard Worker         .backward();
3317*da0073e9SAndroid Build Coastguard Worker     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3318*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3319*da0073e9SAndroid Build Coastguard Worker   }
3320*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits broadcasts weights
3321*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({16, 4});
3322*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({16, 4}) - 0.5;
3323*da0073e9SAndroid Build Coastguard Worker 
3324*da0073e9SAndroid Build Coastguard Worker     auto weight = torch::rand(4);
3325*da0073e9SAndroid Build Coastguard Worker     auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3326*da0073e9SAndroid Build Coastguard Worker         output, target);
3327*da0073e9SAndroid Build Coastguard Worker 
3328*da0073e9SAndroid Build Coastguard Worker     weight = weight.expand({16, 4}).contiguous();
3329*da0073e9SAndroid Build Coastguard Worker     auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3330*da0073e9SAndroid Build Coastguard Worker         output, target);
3331*da0073e9SAndroid Build Coastguard Worker 
3332*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3333*da0073e9SAndroid Build Coastguard Worker 
3334*da0073e9SAndroid Build Coastguard Worker     weight = torch::rand({16, 1});
3335*da0073e9SAndroid Build Coastguard Worker     out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3336*da0073e9SAndroid Build Coastguard Worker         output, target);
3337*da0073e9SAndroid Build Coastguard Worker 
3338*da0073e9SAndroid Build Coastguard Worker     weight = weight.expand({16, 4}).contiguous();
3339*da0073e9SAndroid Build Coastguard Worker     out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3340*da0073e9SAndroid Build Coastguard Worker         output, target);
3341*da0073e9SAndroid Build Coastguard Worker 
3342*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3343*da0073e9SAndroid Build Coastguard Worker   }
3344*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits ones in pos weights are the same as none
3345*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({64, 4});
3346*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({64, 4}) - 0.5;
3347*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::ones({64, 4});
3348*da0073e9SAndroid Build Coastguard Worker 
3349*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3350*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss()(output, target),
3351*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))(
3352*da0073e9SAndroid Build Coastguard Worker             output, target)));
3353*da0073e9SAndroid Build Coastguard Worker   }
3354*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits broadcasts pos weights
3355*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::rand({64, 4});
3356*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::rand({64, 4}) - 0.5;
3357*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::rand(4);
3358*da0073e9SAndroid Build Coastguard Worker     const auto out1 = BCEWithLogitsLoss(
3359*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3360*da0073e9SAndroid Build Coastguard Worker 
3361*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight1 = pos_weight.expand({1, 4});
3362*da0073e9SAndroid Build Coastguard Worker     const auto out2 = BCEWithLogitsLoss(
3363*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3364*da0073e9SAndroid Build Coastguard Worker 
3365*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight2 = pos_weight.expand({64, 4});
3366*da0073e9SAndroid Build Coastguard Worker     const auto out3 = BCEWithLogitsLoss(
3367*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3368*da0073e9SAndroid Build Coastguard Worker 
3369*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out2));
3370*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(out1, out3));
3371*da0073e9SAndroid Build Coastguard Worker   }
3372*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits with pos weight has correct grad at zero
3373*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::zeros({3, 1}, torch::requires_grad());
3374*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::zeros({3, 1});
3375*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::ones({3, 1});
3376*da0073e9SAndroid Build Coastguard Worker     BCEWithLogitsLoss(BCEWithLogitsLossOptions()
3377*da0073e9SAndroid Build Coastguard Worker                           .pos_weight(pos_weight)
3378*da0073e9SAndroid Build Coastguard Worker                           .reduction(torch::kSum))(output, target)
3379*da0073e9SAndroid Build Coastguard Worker         .backward();
3380*da0073e9SAndroid Build Coastguard Worker     const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3381*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3382*da0073e9SAndroid Build Coastguard Worker     const auto grad = output.grad();
3383*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(grad, expected_grad));
3384*da0073e9SAndroid Build Coastguard Worker   }
3385*da0073e9SAndroid Build Coastguard Worker   { // test BCE with logits stability
3386*da0073e9SAndroid Build Coastguard Worker     const auto output = torch::tensor({0., -120.});
3387*da0073e9SAndroid Build Coastguard Worker     const auto target = torch::tensor({0., 1.});
3388*da0073e9SAndroid Build Coastguard Worker     const auto pos_weight = torch::tensor({1., 1.});
3389*da0073e9SAndroid Build Coastguard Worker 
3390*da0073e9SAndroid Build Coastguard Worker     const auto out1 = BCEWithLogitsLoss()(output, target);
3391*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3392*da0073e9SAndroid Build Coastguard Worker 
3393*da0073e9SAndroid Build Coastguard Worker     const auto out2 = BCEWithLogitsLoss(
3394*da0073e9SAndroid Build Coastguard Worker         BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3395*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3396*da0073e9SAndroid Build Coastguard Worker   }
3397*da0073e9SAndroid Build Coastguard Worker }
3398*da0073e9SAndroid Build Coastguard Worker 
3399*da0073e9SAndroid Build Coastguard Worker namespace detail {
3400*da0073e9SAndroid Build Coastguard Worker 
3401*da0073e9SAndroid Build Coastguard Worker namespace F = torch::nn::functional;
3402*da0073e9SAndroid Build Coastguard Worker 
// Loop-based reference for a batched matrix product over the first two
// dimensions: result[p][q] = a[p][q] @ b[p][q].
// a and b are indexed as 4-D tensors (a.size(2) x a.size(3) times
// b.size(2) x b.size(3)); the two leading dims must match. The result is
// always allocated as float32, regardless of the input dtype.
torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) {
  TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0));
  TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1));

  const auto outer = a.size(0);
  const auto inner = a.size(1);
  auto result =
      torch::zeros({outer, inner, a.size(2), b.size(3)}, torch::kFloat32);

  for (const auto p : c10::irange(outer)) {
    for (const auto q : c10::irange(inner)) {
      // Assigning into the indexed view writes the product into `result`.
      result[p][q] = torch::matmul(a[p][q], b[p][q]);
    }
  }
  return result;
}
3415*da0073e9SAndroid Build Coastguard Worker 
// Loop-based reference softmax. For every (i, j, k) index over the first three
// dimensions, the remaining slice x[i][j][k] is normalized as
// exp(s - max(s)) / sum(exp(s - max(s))); subtracting the max keeps exp()
// from overflowing. The output is allocated with the default dtype.
torch::Tensor _softmax(const torch::Tensor& x) {
  auto result = torch::zeros(x.sizes());
  for (const auto i : c10::irange(x.size(0))) {
    for (const auto j : c10::irange(x.size(1))) {
      for (const auto k : c10::irange(x.size(2))) {
        const torch::Tensor slice = x[i][j][k];
        const torch::Tensor peak = torch::max(slice);
        const torch::Tensor shifted_exp = torch::exp(slice - peak);
        result[i][j][k] = shifted_exp / torch::sum(shifted_exp);
      }
    }
  }
  return result;
}
3429*da0073e9SAndroid Build Coastguard Worker 
_scaled_dot_attn_ref(const torch::Tensor & Q,const torch::Tensor & K,const torch::Tensor & V,at::IntArrayRef dims,const torch::Tensor & unseen_mask={},const torch::Tensor & key_padding_mask={},bool average_attn_weights=true)3430*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> _scaled_dot_attn_ref(
3431*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor& Q,
3432*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor& K,
3433*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor& V,
3434*da0073e9SAndroid Build Coastguard Worker     at::IntArrayRef dims,
3435*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor& unseen_mask = {},
3436*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor& key_padding_mask = {},
3437*da0073e9SAndroid Build Coastguard Worker     bool average_attn_weights = true) {
3438*da0073e9SAndroid Build Coastguard Worker   auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3]));
3439*da0073e9SAndroid Build Coastguard Worker   const auto b1 = QKT.size(0);
3440*da0073e9SAndroid Build Coastguard Worker   const auto b2 = QKT.size(1);
3441*da0073e9SAndroid Build Coastguard Worker   const auto s1 = QKT.size(2);
3442*da0073e9SAndroid Build Coastguard Worker   const auto s2 = QKT.size(3);
3443*da0073e9SAndroid Build Coastguard Worker   if (unseen_mask.defined() || key_padding_mask.defined()) {
3444*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(b1)) {
3445*da0073e9SAndroid Build Coastguard Worker       for (const auto j : c10::irange(b2)) {
3446*da0073e9SAndroid Build Coastguard Worker         for (const auto m : c10::irange(s1)) {
3447*da0073e9SAndroid Build Coastguard Worker           for (const auto n : c10::irange(s2)) {
3448*da0073e9SAndroid Build Coastguard Worker             if (unseen_mask.defined() &&
3449*da0073e9SAndroid Build Coastguard Worker                 unseen_mask[m][n].item<double>() == 0) {
3450*da0073e9SAndroid Build Coastguard Worker               QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3451*da0073e9SAndroid Build Coastguard Worker             }
3452*da0073e9SAndroid Build Coastguard Worker             if (key_padding_mask.defined() &&
3453*da0073e9SAndroid Build Coastguard Worker                 key_padding_mask[i][n].item<double>() != 0) {
3454*da0073e9SAndroid Build Coastguard Worker               QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3455*da0073e9SAndroid Build Coastguard Worker             }
3456*da0073e9SAndroid Build Coastguard Worker           }
3457*da0073e9SAndroid Build Coastguard Worker         }
3458*da0073e9SAndroid Build Coastguard Worker       }
3459*da0073e9SAndroid Build Coastguard Worker     }
3460*da0073e9SAndroid Build Coastguard Worker   }
3461*da0073e9SAndroid Build Coastguard Worker   auto reference = _softmax(QKT);
3462*da0073e9SAndroid Build Coastguard Worker   auto ref_attn_weight = reference;
3463*da0073e9SAndroid Build Coastguard Worker   if (average_attn_weights) {
3464*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(bugprone-argument-comment)
3465*da0073e9SAndroid Build Coastguard Worker     ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2;
3466*da0073e9SAndroid Build Coastguard Worker   }
3467*da0073e9SAndroid Build Coastguard Worker   reference = _batchmatmul(reference, V);
3468*da0073e9SAndroid Build Coastguard Worker   return std::tie(reference, ref_attn_weight);
3469*da0073e9SAndroid Build Coastguard Worker }
3470*da0073e9SAndroid Build Coastguard Worker 
// Reference head split: reinterpret X as (dims[0], dims[1], nheads, d_head),
// swap the two middle axes, and return a (dims[0], nheads, dims[1], d_head)
// tensor. X must hold dims[0] * dims[1] * nheads * d_head elements.
torch::Tensor _split_heads_ref(
    const torch::Tensor& X,
    at::IntArrayRef dims,
    int nheads,
    int d_head) {
  const int64_t lead = dims[0];
  const int64_t second = dims[1];
  return X.reshape({lead, second, nheads, d_head})
      .permute({0, 2, 1, 3})
      .reshape({lead, nheads, second, d_head});
}
3480*da0073e9SAndroid Build Coastguard Worker 
// Reference inverse of _split_heads_ref: swap axes 1 and 2 of a 4-D X, then
// flatten the last two axes into one of size nheads * d_head, giving a
// (dims[0], dims[1], nheads * d_head) tensor.
torch::Tensor _combine_heads_ref(
    const torch::Tensor& X,
    at::IntArrayRef dims,
    int nheads,
    int d_head) {
  const int embed_dim = nheads * d_head;
  return X.permute({0, 2, 1, 3}).reshape({dims[0], dims[1], embed_dim});
}
3490*da0073e9SAndroid Build Coastguard Worker 
_fc(torch::Tensor X,torch::Tensor X_weight,torch::Tensor X_bias)3491*da0073e9SAndroid Build Coastguard Worker torch::Tensor _fc(
3492*da0073e9SAndroid Build Coastguard Worker     torch::Tensor X,
3493*da0073e9SAndroid Build Coastguard Worker     torch::Tensor X_weight,
3494*da0073e9SAndroid Build Coastguard Worker     torch::Tensor X_bias) {
3495*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3496*da0073e9SAndroid Build Coastguard Worker   auto X_fc_b = X_bias;
3497*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3498*da0073e9SAndroid Build Coastguard Worker   auto X_fc_w = X_weight;
3499*da0073e9SAndroid Build Coastguard Worker   return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b;
3500*da0073e9SAndroid Build Coastguard Worker }
3501*da0073e9SAndroid Build Coastguard Worker 
_multihead_attn_test_helper(bool add_key_padding_mask=false,bool add_bias_kv=false,bool add_zero_attn=false,bool saved_kv=false,bool same_embed_dim=false,bool average_attn_weights=true)3502*da0073e9SAndroid Build Coastguard Worker void _multihead_attn_test_helper(
3503*da0073e9SAndroid Build Coastguard Worker     bool add_key_padding_mask = false,
3504*da0073e9SAndroid Build Coastguard Worker     bool add_bias_kv = false,
3505*da0073e9SAndroid Build Coastguard Worker     bool add_zero_attn = false,
3506*da0073e9SAndroid Build Coastguard Worker     bool saved_kv = false,
3507*da0073e9SAndroid Build Coastguard Worker     bool same_embed_dim = false,
3508*da0073e9SAndroid Build Coastguard Worker     bool average_attn_weights = true) {
3509*da0073e9SAndroid Build Coastguard Worker   std::random_device device;
3510*da0073e9SAndroid Build Coastguard Worker   std::mt19937 generator(device());
3511*da0073e9SAndroid Build Coastguard Worker   std::uniform_int_distribution<int> d_2_10(2, 10);
3512*da0073e9SAndroid Build Coastguard Worker   std::uniform_int_distribution<int> d_3_10(3, 10);
3513*da0073e9SAndroid Build Coastguard Worker   bool registration_checked = false;
3514*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(100)) {
3515*da0073e9SAndroid Build Coastguard Worker     (void)i; // Suppress unused variable warning
3516*da0073e9SAndroid Build Coastguard Worker     const auto batch_sz = d_2_10(generator);
3517*da0073e9SAndroid Build Coastguard Worker     const auto seq_len = d_2_10(generator);
3518*da0073e9SAndroid Build Coastguard Worker     const auto d_head = d_3_10(generator);
3519*da0073e9SAndroid Build Coastguard Worker     const auto nheads = d_3_10(generator);
3520*da0073e9SAndroid Build Coastguard Worker     const auto d_model = d_head * nheads;
3521*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3522*da0073e9SAndroid Build Coastguard Worker     int kv_dim;
3523*da0073e9SAndroid Build Coastguard Worker     if (same_embed_dim) {
3524*da0073e9SAndroid Build Coastguard Worker       kv_dim = d_model;
3525*da0073e9SAndroid Build Coastguard Worker     } else {
3526*da0073e9SAndroid Build Coastguard Worker       std::uniform_int_distribution<int> d(5, 20);
3527*da0073e9SAndroid Build Coastguard Worker       kv_dim = d(generator);
3528*da0073e9SAndroid Build Coastguard Worker       while (kv_dim == d_model) {
3529*da0073e9SAndroid Build Coastguard Worker         kv_dim = d(generator);
3530*da0073e9SAndroid Build Coastguard Worker       }
3531*da0073e9SAndroid Build Coastguard Worker     }
3532*da0073e9SAndroid Build Coastguard Worker     std::vector<int64_t> dims{batch_sz, seq_len, kv_dim};
3533*da0073e9SAndroid Build Coastguard Worker     torch::Tensor saved_k;
3534*da0073e9SAndroid Build Coastguard Worker     torch::Tensor saved_k_tensor;
3535*da0073e9SAndroid Build Coastguard Worker     torch::Tensor saved_v;
3536*da0073e9SAndroid Build Coastguard Worker     torch::Tensor saved_v_tensor;
3537*da0073e9SAndroid Build Coastguard Worker     if (saved_kv) {
3538*da0073e9SAndroid Build Coastguard Worker       saved_k = torch::rand({batch_sz * nheads, seq_len, d_head});
3539*da0073e9SAndroid Build Coastguard Worker       saved_k_tensor = saved_k;
3540*da0073e9SAndroid Build Coastguard Worker       saved_v = torch::rand({batch_sz * nheads, seq_len, d_head});
3541*da0073e9SAndroid Build Coastguard Worker       saved_v_tensor = saved_v;
3542*da0073e9SAndroid Build Coastguard Worker     }
3543*da0073e9SAndroid Build Coastguard Worker     torch::Tensor key_padding_mask;
3544*da0073e9SAndroid Build Coastguard Worker     torch::Tensor key_padding_mask_tensor;
3545*da0073e9SAndroid Build Coastguard Worker     if (add_key_padding_mask) {
3546*da0073e9SAndroid Build Coastguard Worker       const auto seq_mask = torch::randint(0, 2, {1, seq_len});
3547*da0073e9SAndroid Build Coastguard Worker       key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1;
3548*da0073e9SAndroid Build Coastguard Worker       key_padding_mask_tensor = key_padding_mask;
3549*da0073e9SAndroid Build Coastguard Worker     }
3550*da0073e9SAndroid Build Coastguard Worker     const auto decoder_state = torch::rand({batch_sz, d_model});
3551*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor K = torch::rand(dims);
3552*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3553*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor V = K;
3554*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor Q =
3555*da0073e9SAndroid Build Coastguard Worker         decoder_state.clone().resize_({batch_sz, 1, d_model});
3556*da0073e9SAndroid Build Coastguard Worker     auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat);
3557*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor attn_mask_tensor = attn_mask.clone();
3558*da0073e9SAndroid Build Coastguard Worker     attn_mask_tensor.masked_fill_(
3559*da0073e9SAndroid Build Coastguard Worker         attn_mask_tensor == 0, -std::numeric_limits<double>::infinity());
3560*da0073e9SAndroid Build Coastguard Worker     attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0));
3561*da0073e9SAndroid Build Coastguard Worker 
3562*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3563*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor decoder_state_tensor = decoder_state;
3564*da0073e9SAndroid Build Coastguard Worker     const torch::Tensor source_hid_tensor = K.transpose(0, 1);
3565*da0073e9SAndroid Build Coastguard Worker 
3566*da0073e9SAndroid Build Coastguard Worker     const auto options = MultiheadAttentionOptions(d_model, nheads)
3567*da0073e9SAndroid Build Coastguard Worker                              .add_bias_kv(add_bias_kv)
3568*da0073e9SAndroid Build Coastguard Worker                              .add_zero_attn(add_zero_attn)
3569*da0073e9SAndroid Build Coastguard Worker                              .kdim(kv_dim)
3570*da0073e9SAndroid Build Coastguard Worker                              .vdim(kv_dim);
3571*da0073e9SAndroid Build Coastguard Worker     const auto multihead_attn_module = MultiheadAttention(options);
3572*da0073e9SAndroid Build Coastguard Worker 
3573*da0073e9SAndroid Build Coastguard Worker     if (!registration_checked) {
3574*da0073e9SAndroid Build Coastguard Worker       // make sure parameters are all registered correctly
3575*da0073e9SAndroid Build Coastguard Worker       auto named_parameters = multihead_attn_module->named_parameters();
3576*da0073e9SAndroid Build Coastguard Worker       if (same_embed_dim) {
3577*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("in_proj_weight"));
3578*da0073e9SAndroid Build Coastguard Worker       } else {
3579*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("q_proj_weight"));
3580*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("k_proj_weight"));
3581*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("v_proj_weight"));
3582*da0073e9SAndroid Build Coastguard Worker       }
3583*da0073e9SAndroid Build Coastguard Worker       if (add_bias_kv) {
3584*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("bias_k"));
3585*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(named_parameters.contains("bias_v"));
3586*da0073e9SAndroid Build Coastguard Worker       }
3587*da0073e9SAndroid Build Coastguard Worker       // make sure sub modules are all registered correctly
3588*da0073e9SAndroid Build Coastguard Worker       auto submodules = multihead_attn_module->named_children();
3589*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(submodules.contains("out_proj"));
3590*da0073e9SAndroid Build Coastguard Worker       registration_checked = true;
3591*da0073e9SAndroid Build Coastguard Worker     }
3592*da0073e9SAndroid Build Coastguard Worker 
3593*da0073e9SAndroid Build Coastguard Worker     torch::Tensor bias_k;
3594*da0073e9SAndroid Build Coastguard Worker     torch::Tensor bias_v;
3595*da0073e9SAndroid Build Coastguard Worker     if (add_bias_kv) {
3596*da0073e9SAndroid Build Coastguard Worker       bias_k = multihead_attn_module->bias_k.detach();
3597*da0073e9SAndroid Build Coastguard Worker       bias_v = multihead_attn_module->bias_v.detach();
3598*da0073e9SAndroid Build Coastguard Worker     } else {
3599*da0073e9SAndroid Build Coastguard Worker       bias_k.reset();
3600*da0073e9SAndroid Build Coastguard Worker       bias_v.reset();
3601*da0073e9SAndroid Build Coastguard Worker     }
3602*da0073e9SAndroid Build Coastguard Worker 
3603*da0073e9SAndroid Build Coastguard Worker     torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1);
3604*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3605*da0073e9SAndroid Build Coastguard Worker     torch::Tensor _V = source_hid_tensor;
3606*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3607*da0073e9SAndroid Build Coastguard Worker     torch::Tensor _K = source_hid_tensor;
3608*da0073e9SAndroid Build Coastguard Worker 
3609*da0073e9SAndroid Build Coastguard Worker     torch::Tensor result;
3610*da0073e9SAndroid Build Coastguard Worker     torch::Tensor result_weight;
3611*da0073e9SAndroid Build Coastguard Worker     if (multihead_attn_module->_qkv_same_embed_dim) {
3612*da0073e9SAndroid Build Coastguard Worker       std::tie(result, result_weight) = F::multi_head_attention_forward(
3613*da0073e9SAndroid Build Coastguard Worker           _Q,
3614*da0073e9SAndroid Build Coastguard Worker           _K,
3615*da0073e9SAndroid Build Coastguard Worker           _V,
3616*da0073e9SAndroid Build Coastguard Worker           F::MultiheadAttentionForwardFuncOptions(
3617*da0073e9SAndroid Build Coastguard Worker               /*embed_dim_to_check=*/d_model,
3618*da0073e9SAndroid Build Coastguard Worker               /*num_heads=*/nheads,
3619*da0073e9SAndroid Build Coastguard Worker               /*in_proj_weight=*/multihead_attn_module->in_proj_weight,
3620*da0073e9SAndroid Build Coastguard Worker               /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3621*da0073e9SAndroid Build Coastguard Worker               /*bias_k=*/multihead_attn_module->bias_k,
3622*da0073e9SAndroid Build Coastguard Worker               /*bias_v=*/multihead_attn_module->bias_v,
3623*da0073e9SAndroid Build Coastguard Worker               /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3624*da0073e9SAndroid Build Coastguard Worker               /*dropout_p=*/multihead_attn_module->options.dropout(),
3625*da0073e9SAndroid Build Coastguard Worker               /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3626*da0073e9SAndroid Build Coastguard Worker               /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3627*da0073e9SAndroid Build Coastguard Worker               .training(multihead_attn_module->is_training())
3628*da0073e9SAndroid Build Coastguard Worker               .key_padding_mask(key_padding_mask_tensor)
3629*da0073e9SAndroid Build Coastguard Worker               .need_weights(true)
3630*da0073e9SAndroid Build Coastguard Worker               .attn_mask(attn_mask_tensor)
3631*da0073e9SAndroid Build Coastguard Worker               .static_k(saved_k_tensor)
3632*da0073e9SAndroid Build Coastguard Worker               .static_v(saved_v_tensor)
3633*da0073e9SAndroid Build Coastguard Worker               .average_attn_weights(average_attn_weights));
3634*da0073e9SAndroid Build Coastguard Worker     } else {
3635*da0073e9SAndroid Build Coastguard Worker       std::tie(result, result_weight) = F::multi_head_attention_forward(
3636*da0073e9SAndroid Build Coastguard Worker           _Q,
3637*da0073e9SAndroid Build Coastguard Worker           _K,
3638*da0073e9SAndroid Build Coastguard Worker           _V,
3639*da0073e9SAndroid Build Coastguard Worker           F::MultiheadAttentionForwardFuncOptions(
3640*da0073e9SAndroid Build Coastguard Worker               /*embed_dim_to_check=*/d_model,
3641*da0073e9SAndroid Build Coastguard Worker               /*num_heads=*/nheads,
3642*da0073e9SAndroid Build Coastguard Worker               /*in_proj_weight=*/{},
3643*da0073e9SAndroid Build Coastguard Worker               /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3644*da0073e9SAndroid Build Coastguard Worker               /*bias_k=*/multihead_attn_module->bias_k,
3645*da0073e9SAndroid Build Coastguard Worker               /*bias_v=*/multihead_attn_module->bias_v,
3646*da0073e9SAndroid Build Coastguard Worker               /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3647*da0073e9SAndroid Build Coastguard Worker               /*dropout_p=*/multihead_attn_module->options.dropout(),
3648*da0073e9SAndroid Build Coastguard Worker               /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3649*da0073e9SAndroid Build Coastguard Worker               /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3650*da0073e9SAndroid Build Coastguard Worker               .training(multihead_attn_module->is_training())
3651*da0073e9SAndroid Build Coastguard Worker               .key_padding_mask(key_padding_mask_tensor)
3652*da0073e9SAndroid Build Coastguard Worker               .need_weights(true)
3653*da0073e9SAndroid Build Coastguard Worker               .attn_mask(attn_mask_tensor)
3654*da0073e9SAndroid Build Coastguard Worker               .use_separate_proj_weight(true)
3655*da0073e9SAndroid Build Coastguard Worker               .q_proj_weight(multihead_attn_module->q_proj_weight)
3656*da0073e9SAndroid Build Coastguard Worker               .k_proj_weight(multihead_attn_module->k_proj_weight)
3657*da0073e9SAndroid Build Coastguard Worker               .v_proj_weight(multihead_attn_module->v_proj_weight)
3658*da0073e9SAndroid Build Coastguard Worker               .static_k(saved_k_tensor)
3659*da0073e9SAndroid Build Coastguard Worker               .static_v(saved_v_tensor)
3660*da0073e9SAndroid Build Coastguard Worker               .average_attn_weights(average_attn_weights));
3661*da0073e9SAndroid Build Coastguard Worker     }
3662*da0073e9SAndroid Build Coastguard Worker     result = result.squeeze(0).detach();
3663*da0073e9SAndroid Build Coastguard Worker     torch::Tensor q_proj_weight;
3664*da0073e9SAndroid Build Coastguard Worker     torch::Tensor k_proj_weight;
3665*da0073e9SAndroid Build Coastguard Worker     torch::Tensor v_proj_weight;
3666*da0073e9SAndroid Build Coastguard Worker     if (multihead_attn_module->_qkv_same_embed_dim) {
3667*da0073e9SAndroid Build Coastguard Worker       q_proj_weight =
3668*da0073e9SAndroid Build Coastguard Worker           multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model);
3669*da0073e9SAndroid Build Coastguard Worker       k_proj_weight = multihead_attn_module->in_proj_weight.slice(
3670*da0073e9SAndroid Build Coastguard Worker           /*dim=*/0, d_model, (d_model * 2));
3671*da0073e9SAndroid Build Coastguard Worker       v_proj_weight =
3672*da0073e9SAndroid Build Coastguard Worker           multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2));
3673*da0073e9SAndroid Build Coastguard Worker     } else {
3674*da0073e9SAndroid Build Coastguard Worker       q_proj_weight = multihead_attn_module->q_proj_weight;
3675*da0073e9SAndroid Build Coastguard Worker       k_proj_weight = multihead_attn_module->k_proj_weight;
3676*da0073e9SAndroid Build Coastguard Worker       v_proj_weight = multihead_attn_module->v_proj_weight;
3677*da0073e9SAndroid Build Coastguard Worker     }
3678*da0073e9SAndroid Build Coastguard Worker     auto Q_fc =
3679*da0073e9SAndroid Build Coastguard Worker         _fc(Q,
3680*da0073e9SAndroid Build Coastguard Worker             q_proj_weight,
3681*da0073e9SAndroid Build Coastguard Worker             multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model));
3682*da0073e9SAndroid Build Coastguard Worker     auto K_fc =
3683*da0073e9SAndroid Build Coastguard Worker         _fc(K,
3684*da0073e9SAndroid Build Coastguard Worker             k_proj_weight,
3685*da0073e9SAndroid Build Coastguard Worker             multihead_attn_module->in_proj_bias.slice(
3686*da0073e9SAndroid Build Coastguard Worker                 /*dim=*/0, d_model, (d_model * 2)));
3687*da0073e9SAndroid Build Coastguard Worker     auto V_fc = _fc(
3688*da0073e9SAndroid Build Coastguard Worker         V,
3689*da0073e9SAndroid Build Coastguard Worker         v_proj_weight,
3690*da0073e9SAndroid Build Coastguard Worker         multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2)));
3691*da0073e9SAndroid Build Coastguard Worker 
3692*da0073e9SAndroid Build Coastguard Worker     if (add_bias_kv) {
3693*da0073e9SAndroid Build Coastguard Worker       K_fc = torch::cat(
3694*da0073e9SAndroid Build Coastguard Worker           {K_fc,
3695*da0073e9SAndroid Build Coastguard Worker            bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)},
3696*da0073e9SAndroid Build Coastguard Worker           /*dim=*/1);
3697*da0073e9SAndroid Build Coastguard Worker       V_fc = torch::cat(
3698*da0073e9SAndroid Build Coastguard Worker           {V_fc,
3699*da0073e9SAndroid Build Coastguard Worker            bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)},
3700*da0073e9SAndroid Build Coastguard Worker           /*dim=*/1);
3701*da0073e9SAndroid Build Coastguard Worker       if (attn_mask.defined()) {
3702*da0073e9SAndroid Build Coastguard Worker         attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3703*da0073e9SAndroid Build Coastguard Worker       }
3704*da0073e9SAndroid Build Coastguard Worker       if (key_padding_mask.defined()) {
3705*da0073e9SAndroid Build Coastguard Worker         key_padding_mask = torch::cat(
3706*da0073e9SAndroid Build Coastguard Worker             {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3707*da0073e9SAndroid Build Coastguard Worker             /*dim=*/1);
3708*da0073e9SAndroid Build Coastguard Worker       }
3709*da0073e9SAndroid Build Coastguard Worker       dims[1] += 1;
3710*da0073e9SAndroid Build Coastguard Worker     }
3711*da0073e9SAndroid Build Coastguard Worker     const auto Q_split =
3712*da0073e9SAndroid Build Coastguard Worker         _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head);
3713*da0073e9SAndroid Build Coastguard Worker     torch::Tensor K_split;
3714*da0073e9SAndroid Build Coastguard Worker     if (saved_k.defined()) {
3715*da0073e9SAndroid Build Coastguard Worker       K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head});
3716*da0073e9SAndroid Build Coastguard Worker     } else {
3717*da0073e9SAndroid Build Coastguard Worker       K_split = _split_heads_ref(K_fc, dims, nheads, d_head);
3718*da0073e9SAndroid Build Coastguard Worker     }
3719*da0073e9SAndroid Build Coastguard Worker     torch::Tensor V_split;
3720*da0073e9SAndroid Build Coastguard Worker     if (saved_v.defined()) {
3721*da0073e9SAndroid Build Coastguard Worker       V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head});
3722*da0073e9SAndroid Build Coastguard Worker     } else {
3723*da0073e9SAndroid Build Coastguard Worker       V_split = _split_heads_ref(V_fc, dims, nheads, d_head);
3724*da0073e9SAndroid Build Coastguard Worker     }
3725*da0073e9SAndroid Build Coastguard Worker     if (add_zero_attn) {
3726*da0073e9SAndroid Build Coastguard Worker       dims[1] += 1;
3727*da0073e9SAndroid Build Coastguard Worker       K_split = torch::cat(
3728*da0073e9SAndroid Build Coastguard Worker           {K_split,
3729*da0073e9SAndroid Build Coastguard Worker            torch::zeros(
3730*da0073e9SAndroid Build Coastguard Worker                {K_split.size(0), K_split.size(1), 1, K_split.size(3)})},
3731*da0073e9SAndroid Build Coastguard Worker           /*dim=*/2);
3732*da0073e9SAndroid Build Coastguard Worker       V_split = torch::cat(
3733*da0073e9SAndroid Build Coastguard Worker           {V_split,
3734*da0073e9SAndroid Build Coastguard Worker            torch::zeros(
3735*da0073e9SAndroid Build Coastguard Worker                {V_split.size(0), V_split.size(1), 1, V_split.size(3)})},
3736*da0073e9SAndroid Build Coastguard Worker           /*dim=*/2);
3737*da0073e9SAndroid Build Coastguard Worker       if (attn_mask.defined()) {
3738*da0073e9SAndroid Build Coastguard Worker         attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3739*da0073e9SAndroid Build Coastguard Worker       }
3740*da0073e9SAndroid Build Coastguard Worker       if (key_padding_mask.defined()) {
3741*da0073e9SAndroid Build Coastguard Worker         key_padding_mask = torch::cat(
3742*da0073e9SAndroid Build Coastguard Worker             {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3743*da0073e9SAndroid Build Coastguard Worker             /*dim=*/1);
3744*da0073e9SAndroid Build Coastguard Worker       }
3745*da0073e9SAndroid Build Coastguard Worker     }
3746*da0073e9SAndroid Build Coastguard Worker     auto [attn_heads, ref_attn_weight] = _scaled_dot_attn_ref(
3747*da0073e9SAndroid Build Coastguard Worker         Q_split,
3748*da0073e9SAndroid Build Coastguard Worker         K_split,
3749*da0073e9SAndroid Build Coastguard Worker         V_split,
3750*da0073e9SAndroid Build Coastguard Worker         Q_split.sizes(),
3751*da0073e9SAndroid Build Coastguard Worker         attn_mask,
3752*da0073e9SAndroid Build Coastguard Worker         key_padding_mask,
3753*da0073e9SAndroid Build Coastguard Worker         average_attn_weights);
3754*da0073e9SAndroid Build Coastguard Worker     const auto combined_attn_heads =
3755*da0073e9SAndroid Build Coastguard Worker         _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head);
3756*da0073e9SAndroid Build Coastguard Worker     auto reference =
3757*da0073e9SAndroid Build Coastguard Worker         _fc(combined_attn_heads,
3758*da0073e9SAndroid Build Coastguard Worker             multihead_attn_module->out_proj->weight,
3759*da0073e9SAndroid Build Coastguard Worker             multihead_attn_module->out_proj->bias);
3760*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(bugprone-argument-comment)
3761*da0073e9SAndroid Build Coastguard Worker     reference = torch::squeeze(reference, /*axis=*/1);
3762*da0073e9SAndroid Build Coastguard Worker 
3763*da0073e9SAndroid Build Coastguard Worker     // result = reference
3764*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(result.sizes(), std::vector<int64_t>({batch_sz, d_model}));
3765*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
3766*da0073e9SAndroid Build Coastguard Worker         torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true));
3767*da0073e9SAndroid Build Coastguard Worker 
3768*da0073e9SAndroid Build Coastguard Worker     // result_weight = ref_attn_weight
3769*da0073e9SAndroid Build Coastguard Worker     result_weight = result_weight.detach();
3770*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes());
3771*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
3772*da0073e9SAndroid Build Coastguard Worker         result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true));
3773*da0073e9SAndroid Build Coastguard Worker   }
3774*da0073e9SAndroid Build Coastguard Worker }
3775*da0073e9SAndroid Build Coastguard Worker } // namespace detail
3776*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,MultiheadAttention)3777*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, MultiheadAttention) {
3778*da0073e9SAndroid Build Coastguard Worker   using namespace ::detail;
3779*da0073e9SAndroid Build Coastguard Worker 
3780*da0073e9SAndroid Build Coastguard Worker   for (auto average_attn_weights : {false, true}) {
3781*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_add_zero_attn
3782*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3783*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/false,
3784*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/false,
3785*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/true,
3786*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/false,
3787*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3788*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3789*da0073e9SAndroid Build Coastguard Worker 
3790*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_add_bias_kv
3791*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3792*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/false,
3793*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/true,
3794*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/false,
3795*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/false,
3796*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3797*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3798*da0073e9SAndroid Build Coastguard Worker 
3799*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_no_masking():
3800*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper();
3801*da0073e9SAndroid Build Coastguard Worker 
3802*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_key_padding_mask
3803*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3804*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/true,
3805*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/false,
3806*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/false,
3807*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/false,
3808*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3809*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3810*da0073e9SAndroid Build Coastguard Worker 
3811*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_saved_kv
3812*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3813*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/false,
3814*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/false,
3815*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/false,
3816*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/true,
3817*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3818*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3819*da0073e9SAndroid Build Coastguard Worker 
3820*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_add_bias_kv_zero_attn
3821*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3822*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/true,
3823*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/true,
3824*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/true,
3825*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/false,
3826*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3827*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3828*da0073e9SAndroid Build Coastguard Worker 
3829*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_all_arguments1
3830*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3831*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/true,
3832*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/false,
3833*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/true,
3834*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/true,
3835*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/false,
3836*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3837*da0073e9SAndroid Build Coastguard Worker 
3838*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
3839*da0073e9SAndroid Build Coastguard Worker         // test_multihead_attn_all_arguments2
3840*da0073e9SAndroid Build Coastguard Worker         _multihead_attn_test_helper(
3841*da0073e9SAndroid Build Coastguard Worker             /*add_key_padding_mask=*/true,
3842*da0073e9SAndroid Build Coastguard Worker             /*add_bias_kv=*/true,
3843*da0073e9SAndroid Build Coastguard Worker             /*add_zero_attn=*/true,
3844*da0073e9SAndroid Build Coastguard Worker             /*saved_kv=*/true,
3845*da0073e9SAndroid Build Coastguard Worker             /*same_embed_dim=*/false,
3846*da0073e9SAndroid Build Coastguard Worker             /*average_attn_weights=*/average_attn_weights),
3847*da0073e9SAndroid Build Coastguard Worker         "bias cannot be added to static key");
3848*da0073e9SAndroid Build Coastguard Worker 
3849*da0073e9SAndroid Build Coastguard Worker     // test_multihead_attn_all_arguments3
3850*da0073e9SAndroid Build Coastguard Worker     _multihead_attn_test_helper(
3851*da0073e9SAndroid Build Coastguard Worker         /*add_key_padding_mask=*/true,
3852*da0073e9SAndroid Build Coastguard Worker         /*add_bias_kv=*/false,
3853*da0073e9SAndroid Build Coastguard Worker         /*add_zero_attn=*/true,
3854*da0073e9SAndroid Build Coastguard Worker         /*saved_kv=*/true,
3855*da0073e9SAndroid Build Coastguard Worker         /*same_embed_dim=*/true,
3856*da0073e9SAndroid Build Coastguard Worker         /*average_attn_weights=*/average_attn_weights);
3857*da0073e9SAndroid Build Coastguard Worker   }
3858*da0073e9SAndroid Build Coastguard Worker }
3859*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintIdentity) {
  // Identity takes no options, so its printed form has empty parentheses.
  const std::string printed = c10::str(Identity());
  ASSERT_EQ(printed, "torch::nn::Identity()");
}
3863*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintFlatten) {
  // A default-constructed Flatten is expected to print start_dim=1, end_dim=-1.
  const auto default_flatten = Flatten();
  ASSERT_EQ(
      c10::str(default_flatten), "torch::nn::Flatten(start_dim=1, end_dim=-1)");

  // Explicitly configured dims must be echoed back in the printed form.
  const auto ranged_options = FlattenOptions().start_dim(2).end_dim(4);
  const auto ranged_flatten = Flatten(ranged_options);
  ASSERT_EQ(
      c10::str(ranged_flatten), "torch::nn::Flatten(start_dim=2, end_dim=4)");
}
3870*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintUnflatten) {
  // Positional form: integer dim plus a list of target sizes.
  const auto by_index = Unflatten(UnflattenOptions(0, {2, 2}));
  ASSERT_EQ(
      c10::str(by_index),
      "torch::nn::Unflatten(dim=0, unflattened_size={2, 2})");

  // Named form: dimension name plus (name, size) pairs. The printed form
  // quotes each name.
  using NamedSize = std::pair<std::string, int64_t>;
  const auto by_name = Unflatten(
      UnflattenOptions("B", {NamedSize{"B1", 2}, NamedSize{"B2", 2}}));
  ASSERT_EQ(
      c10::str(by_name),
      "torch::nn::Unflatten(dim=\"B\", unflattened_size={{\"B1\", 2}, {\"B2\", 2}})");
}
3882*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ReflectionPad1d) {
  // Both cases pad a (1, 2, 4) ramp: channel 0 holds 0..3, channel 1 holds
  // 4..7. Reflection mirrors around the edge element without repeating it.
  {
    // Same amount (2) on the left and right of the last dimension.
    const auto expected = torch::tensor(
        {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}},
        torch::kFloat);
    ReflectionPad1d pad(ReflectionPad1dOptions(2));
    const auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    const auto padded = pad(input);
    ASSERT_TRUE(padded.allclose(expected));
  }
  {
    // Asymmetric: 3 elements on the left, 1 on the right.
    const auto expected = torch::tensor(
        {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}},
        torch::kFloat);
    ReflectionPad1d pad(ReflectionPad1dOptions({3, 1}));
    const auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    const auto padded = pad(input);
    ASSERT_TRUE(padded.allclose(expected));
  }
}
3903*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, ReflectionPad2d) {
  // Both cases pad a single 3x3 plane holding 0..8 in row-major order.
  {
    // Uniform padding of 2 on every side yields a 7x7 plane.
    const auto expected = torch::tensor(
        {{{{8., 7., 6., 7., 8., 7., 6.},
           {5., 4., 3., 4., 5., 4., 3.},
           {2., 1., 0., 1., 2., 1., 0.},
           {5., 4., 3., 4., 5., 4., 3.},
           {8., 7., 6., 7., 8., 7., 6.},
           {5., 4., 3., 4., 5., 4., 3.},
           {2., 1., 0., 1., 2., 1., 0.}}}},
        torch::kFloat);
    ReflectionPad2d pad(ReflectionPad2dOptions(2));
    const auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    const auto padded = pad(input);
    ASSERT_TRUE(padded.allclose(expected));
  }
  {
    // Padding {1, 1, 2, 0}: one column on each side and two rows on top;
    // the expected 5x5 plane has no rows added below.
    const auto expected = torch::tensor(
        {{{{7., 6., 7., 8., 7.},
           {4., 3., 4., 5., 4.},
           {1., 0., 1., 2., 1.},
           {4., 3., 4., 5., 4.},
           {7., 6., 7., 8., 7.}}}},
        torch::kFloat);
    ReflectionPad2d pad(ReflectionPad2dOptions({1, 1, 2, 0}));
    const auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    const auto padded = pad(input);
    ASSERT_TRUE(padded.allclose(expected));
  }
}
3934*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,ReflectionPad3d)3935*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, ReflectionPad3d) {
3936*da0073e9SAndroid Build Coastguard Worker   {
3937*da0073e9SAndroid Build Coastguard Worker     ReflectionPad3d m(ReflectionPad3dOptions(1));
3938*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
3939*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
3940*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
3941*da0073e9SAndroid Build Coastguard Worker         {{{{{7., 6., 7., 6.},
3942*da0073e9SAndroid Build Coastguard Worker             {5., 4., 5., 4.},
3943*da0073e9SAndroid Build Coastguard Worker             {7., 6., 7., 6.},
3944*da0073e9SAndroid Build Coastguard Worker             {5., 4., 5., 4.}},
3945*da0073e9SAndroid Build Coastguard Worker            {{3., 2., 3., 2.},
3946*da0073e9SAndroid Build Coastguard Worker             {1., 0., 1., 0.},
3947*da0073e9SAndroid Build Coastguard Worker             {3., 2., 3., 2.},
3948*da0073e9SAndroid Build Coastguard Worker             {1., 0., 1., 0.}},
3949*da0073e9SAndroid Build Coastguard Worker            {{7., 6., 7., 6.},
3950*da0073e9SAndroid Build Coastguard Worker             {5., 4., 5., 4.},
3951*da0073e9SAndroid Build Coastguard Worker             {7., 6., 7., 6.},
3952*da0073e9SAndroid Build Coastguard Worker             {5., 4., 5., 4.}},
3953*da0073e9SAndroid Build Coastguard Worker            {{3., 2., 3., 2.},
3954*da0073e9SAndroid Build Coastguard Worker             {1., 0., 1., 0.},
3955*da0073e9SAndroid Build Coastguard Worker             {3., 2., 3., 2.},
3956*da0073e9SAndroid Build Coastguard Worker             {1., 0., 1., 0.}}}}},
3957*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
3958*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
3959*da0073e9SAndroid Build Coastguard Worker   }
3960*da0073e9SAndroid Build Coastguard Worker   {
3961*da0073e9SAndroid Build Coastguard Worker     ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2}));
3962*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2});
3963*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
3964*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
3965*da0073e9SAndroid Build Coastguard Worker         {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3966*da0073e9SAndroid Build Coastguard Worker            {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}},
3967*da0073e9SAndroid Build Coastguard Worker            {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3968*da0073e9SAndroid Build Coastguard Worker            {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3969*da0073e9SAndroid Build Coastguard Worker            {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}},
3970*da0073e9SAndroid Build Coastguard Worker            {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3971*da0073e9SAndroid Build Coastguard Worker            {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}},
3972*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
3973*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 3, 3}));
3974*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
3975*da0073e9SAndroid Build Coastguard Worker   }
3976*da0073e9SAndroid Build Coastguard Worker }
// ReplicationPad1d pads the last dimension by repeating its edge values.
// Padding is either one size for both sides or an explicit {left, right}.
TEST_F(ModulesTest, ReplicationPad1d) {
  {
    // Two copies of each edge value on both sides: length 4 -> 8.
    ReplicationPad1d m(ReplicationPad1dOptions(2));
    auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 8}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // Three copies of the first value on the left, one of the last on the right.
    ReplicationPad1d m(ReplicationPad1dOptions({3, 1}));
    auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 8}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
3997*da0073e9SAndroid Build Coastguard Worker 
// ReplicationPad2d pads the last two dimensions by repeating edge values.
// Explicit padding is ordered {left, right, top, bottom}.
TEST_F(ModulesTest, ReplicationPad2d) {
  {
    // Pad 2 on every side of a 3x3 image: 3x3 -> 7x7.
    ReplicationPad2d m(ReplicationPad2dOptions(2));
    auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{{0., 0., 0., 1., 2., 2., 2.},
           {0., 0., 0., 1., 2., 2., 2.},
           {0., 0., 0., 1., 2., 2., 2.},
           {3., 3., 3., 4., 5., 5., 5.},
           {6., 6., 6., 7., 8., 8., 8.},
           {6., 6., 6., 7., 8., 8., 8.},
           {6., 6., 6., 7., 8., 8., 8.}}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 7}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // One column on each side, two rows on top, none at the bottom: 3x3 -> 5x5.
    ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0}));
    auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{{0., 0., 1., 2., 2.},
           {0., 0., 1., 2., 2.},
           {0., 0., 1., 2., 2.},
           {3., 3., 4., 5., 5.},
           {6., 6., 7., 8., 8.}}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
4028*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,ReplicationPad3d)4029*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, ReplicationPad3d) {
4030*da0073e9SAndroid Build Coastguard Worker   {
4031*da0073e9SAndroid Build Coastguard Worker     ReplicationPad3d m(ReplicationPad3dOptions(1));
4032*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4033*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4034*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4035*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 0., 1., 1.},
4036*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 1.},
4037*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3.},
4038*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3.}},
4039*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 1., 1.},
4040*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 1.},
4041*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3.},
4042*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3.}},
4043*da0073e9SAndroid Build Coastguard Worker            {{4., 4., 5., 5.},
4044*da0073e9SAndroid Build Coastguard Worker             {4., 4., 5., 5.},
4045*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7.},
4046*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7.}},
4047*da0073e9SAndroid Build Coastguard Worker            {{4., 4., 5., 5.},
4048*da0073e9SAndroid Build Coastguard Worker             {4., 4., 5., 5.},
4049*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7.},
4050*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7.}}}}},
4051*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4052*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4053*da0073e9SAndroid Build Coastguard Worker   }
4054*da0073e9SAndroid Build Coastguard Worker   {
4055*da0073e9SAndroid Build Coastguard Worker     ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}));
4056*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4057*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4058*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4059*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 0., 1., 1., 1.},
4060*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 1., 1.},
4061*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.},
4062*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.},
4063*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.}},
4064*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 1., 1., 1.},
4065*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 1., 1.},
4066*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.},
4067*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.},
4068*da0073e9SAndroid Build Coastguard Worker             {2., 2., 3., 3., 3.}},
4069*da0073e9SAndroid Build Coastguard Worker            {{4., 4., 5., 5., 5.},
4070*da0073e9SAndroid Build Coastguard Worker             {4., 4., 5., 5., 5.},
4071*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4072*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4073*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.}},
4074*da0073e9SAndroid Build Coastguard Worker            {{4., 4., 5., 5., 5.},
4075*da0073e9SAndroid Build Coastguard Worker             {4., 4., 5., 5., 5.},
4076*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4077*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4078*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.}},
4079*da0073e9SAndroid Build Coastguard Worker            {{4., 4., 5., 5., 5.},
4080*da0073e9SAndroid Build Coastguard Worker             {4., 4., 5., 5., 5.},
4081*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4082*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.},
4083*da0073e9SAndroid Build Coastguard Worker             {6., 6., 7., 7., 7.}}}}},
4084*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4085*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4086*da0073e9SAndroid Build Coastguard Worker   }
4087*da0073e9SAndroid Build Coastguard Worker }
4088*da0073e9SAndroid Build Coastguard Worker 
// ZeroPad1d pads the last dimension with zeros.
// Padding is either one size for both sides or an explicit {left, right}.
TEST_F(ModulesTest, ZeroPad1d) {
  {
    // Two zeros on each side: length 4 -> 8.
    ZeroPad1d m(ZeroPad1dOptions(2));
    auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{0., 0., 0., 1., 2., 3., 0., 0.}, {0., 0., 4., 5., 6., 7., 0., 0.}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 8}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // Three zeros on the left, one on the right: length 3 -> 7.
    ZeroPad1d m(ZeroPad1dOptions({3, 1}));
    auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{0., 0., 0., 0., 1., 2., 0.}, {0., 0., 0., 3., 4., 5., 0.}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 7}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
4109*da0073e9SAndroid Build Coastguard Worker 
// ZeroPad2d pads the last two dimensions with zeros.
// Explicit padding is ordered {left, right, top, bottom}.
TEST_F(ModulesTest, ZeroPad2d) {
  {
    // Pad 2 on every side of a 3x3 image: 3x3 -> 7x7.
    ZeroPad2d m(ZeroPad2dOptions(2));
    auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{{0., 0., 0., 0., 0., 0., 0.},
           {0., 0., 0., 0., 0., 0., 0.},
           {0., 0., 0., 1., 2., 0., 0.},
           {0., 0., 3., 4., 5., 0., 0.},
           {0., 0., 6., 7., 8., 0., 0.},
           {0., 0., 0., 0., 0., 0., 0.},
           {0., 0., 0., 0., 0., 0., 0.}}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 7}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // One column on each side, two rows on top, none at the bottom: 3x3 -> 5x5.
    ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0}));
    auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{{0., 0., 0., 0., 0.},
           {0., 0., 0., 0., 0.},
           {0., 0., 1., 2., 0.},
           {0., 3., 4., 5., 0.},
           {0., 6., 7., 8., 0.}}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
4140*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,ZeroPad3d)4141*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, ZeroPad3d) {
4142*da0073e9SAndroid Build Coastguard Worker   {
4143*da0073e9SAndroid Build Coastguard Worker     ZeroPad3d m(ZeroPad3dOptions(1));
4144*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4145*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4146*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4147*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 0., 0., 0.},
4148*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.},
4149*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.},
4150*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.}},
4151*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0.},
4152*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 0.},
4153*da0073e9SAndroid Build Coastguard Worker             {0., 2., 3., 0.},
4154*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.}},
4155*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0.},
4156*da0073e9SAndroid Build Coastguard Worker             {0., 4., 5., 0.},
4157*da0073e9SAndroid Build Coastguard Worker             {0., 6., 7., 0.},
4158*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.}},
4159*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0.},
4160*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.},
4161*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.},
4162*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0.}}}}},
4163*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4164*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4165*da0073e9SAndroid Build Coastguard Worker   }
4166*da0073e9SAndroid Build Coastguard Worker   {
4167*da0073e9SAndroid Build Coastguard Worker     ZeroPad3d m(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}));
4168*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4169*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4170*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4171*da0073e9SAndroid Build Coastguard Worker         {{{{{0., 0., 0., 0., 0.},
4172*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4173*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4174*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4175*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.}},
4176*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0., 0.},
4177*da0073e9SAndroid Build Coastguard Worker             {0., 0., 1., 0., 0.},
4178*da0073e9SAndroid Build Coastguard Worker             {0., 2., 3., 0., 0.},
4179*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4180*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.}},
4181*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0., 0.},
4182*da0073e9SAndroid Build Coastguard Worker             {0., 4., 5., 0., 0.},
4183*da0073e9SAndroid Build Coastguard Worker             {0., 6., 7., 0., 0.},
4184*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4185*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.}},
4186*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0., 0.},
4187*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4188*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4189*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4190*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.}},
4191*da0073e9SAndroid Build Coastguard Worker            {{0., 0., 0., 0., 0.},
4192*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4193*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4194*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.},
4195*da0073e9SAndroid Build Coastguard Worker             {0., 0., 0., 0., 0.}}}}},
4196*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4197*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4198*da0073e9SAndroid Build Coastguard Worker   }
4199*da0073e9SAndroid Build Coastguard Worker }
4200*da0073e9SAndroid Build Coastguard Worker 
// ConstantPad1d pads the last dimension with a fixed value (3.5 here).
// Padding is either one size for both sides or an explicit {left, right}.
TEST_F(ModulesTest, ConstantPad1d) {
  {
    // Two fill values on each side: length 4 -> 8.
    ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
    auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000},
          {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 8}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // Three fill values on the left, one on the right: length 3 -> 7.
    ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5));
    auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000},
          {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}},
        torch::kFloat);
    // allclose() broadcasts its operands, so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 7}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
4223*da0073e9SAndroid Build Coastguard Worker 
// ConstantPad2d pads the last two dimensions with a fixed value (3.5 here).
// Explicit padding is ordered {left, right, top, bottom}. The inputs below are
// 3-D, and only their last two dimensions are padded.
TEST_F(ModulesTest, ConstantPad2d) {
  {
    // Pad 2 on every side of a 2x2 plane: 2x2 -> 6x6.
    ConstantPad2d m(ConstantPad2dOptions(2, 3.5));
    auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
          {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
          {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
          {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
          {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
          {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
        torch::kFloat);
    // allclose() broadcasts its operands (a dropped leading dim would still
    // match), so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 6, 6}));
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // Three columns on the left, none on the right, two rows on top, one at
    // the bottom: 2x2 -> 5x5.
    ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5));
    auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
    auto output = m(input);
    auto expected = torch::tensor(
        {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
          {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
          {3.5000, 3.5000, 3.5000, 0.0000, 1.0000},
          {3.5000, 3.5000, 3.5000, 2.0000, 3.0000},
          {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
        torch::kFloat);
    // allclose() broadcasts its operands (a dropped leading dim would still
    // match), so pin the output shape explicitly.
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 5, 5}));
    ASSERT_TRUE(output.allclose(expected));
  }
}
4253*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,ConstantPad3d)4254*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, ConstantPad3d) {
4255*da0073e9SAndroid Build Coastguard Worker   {
4256*da0073e9SAndroid Build Coastguard Worker     ConstantPad3d m(ConstantPad3dOptions(1, 3.5));
4257*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4258*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4259*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4260*da0073e9SAndroid Build Coastguard Worker         {{{{{3.5000, 3.5000, 3.5000, 3.5000},
4261*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000},
4262*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000},
4263*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000}},
4264*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000},
4265*da0073e9SAndroid Build Coastguard Worker             {3.5000, 0.0000, 1.0000, 3.5000},
4266*da0073e9SAndroid Build Coastguard Worker             {3.5000, 2.0000, 3.0000, 3.5000},
4267*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000}},
4268*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000},
4269*da0073e9SAndroid Build Coastguard Worker             {3.5000, 4.0000, 5.0000, 3.5000},
4270*da0073e9SAndroid Build Coastguard Worker             {3.5000, 6.0000, 7.0000, 3.5000},
4271*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000}},
4272*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000},
4273*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000},
4274*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000},
4275*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000}}}}},
4276*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4277*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4278*da0073e9SAndroid Build Coastguard Worker   }
4279*da0073e9SAndroid Build Coastguard Worker   {
4280*da0073e9SAndroid Build Coastguard Worker     ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5));
4281*da0073e9SAndroid Build Coastguard Worker     auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4282*da0073e9SAndroid Build Coastguard Worker     auto output = m(input);
4283*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
4284*da0073e9SAndroid Build Coastguard Worker         {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4285*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4286*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4287*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4288*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4289*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4290*da0073e9SAndroid Build Coastguard Worker             {3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4291*da0073e9SAndroid Build Coastguard Worker             {3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4292*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4293*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4294*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4295*da0073e9SAndroid Build Coastguard Worker             {3.5000, 4.0000, 5.0000, 3.5000, 3.5000},
4296*da0073e9SAndroid Build Coastguard Worker             {3.5000, 6.0000, 7.0000, 3.5000, 3.5000},
4297*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4298*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4299*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4300*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4301*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4302*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4303*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4304*da0073e9SAndroid Build Coastguard Worker            {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4305*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4306*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4307*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4308*da0073e9SAndroid Build Coastguard Worker             {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}},
4309*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
4310*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(output.allclose(expected));
4311*da0073e9SAndroid Build Coastguard Worker   }
4312*da0073e9SAndroid Build Coastguard Worker }
4313*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,CrossMapLRN2d)4314*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, CrossMapLRN2d) {
4315*da0073e9SAndroid Build Coastguard Worker   /// size 3, default options
4316*da0073e9SAndroid Build Coastguard Worker   auto input =
4317*da0073e9SAndroid Build Coastguard Worker       torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true);
4318*da0073e9SAndroid Build Coastguard Worker   auto expected = torch::tensor(
4319*da0073e9SAndroid Build Coastguard Worker       {{{{0.00000000, 0.99997497, 1.99980010},
4320*da0073e9SAndroid Build Coastguard Worker          {2.99932500, 3.99840070, 4.99687700},
4321*da0073e9SAndroid Build Coastguard Worker          {5.99460600, 6.99143740, 7.98722360}}}},
4322*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4323*da0073e9SAndroid Build Coastguard Worker   auto grad_expected = torch::tensor(
4324*da0073e9SAndroid Build Coastguard Worker       {{{{1.00000000, 0.99992496, 0.99970007},
4325*da0073e9SAndroid Build Coastguard Worker          {0.99932520, 0.99880093, 0.99812720},
4326*da0073e9SAndroid Build Coastguard Worker          {0.99730474, 0.99633380, 0.99521490}}}},
4327*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4328*da0073e9SAndroid Build Coastguard Worker   auto crossmaplrn2d = CrossMapLRN2d(3);
4329*da0073e9SAndroid Build Coastguard Worker   auto output = crossmaplrn2d(input);
4330*da0073e9SAndroid Build Coastguard Worker   output.sum().backward();
4331*da0073e9SAndroid Build Coastguard Worker 
4332*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(input.grad().allclose(grad_expected));
4333*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
4334*da0073e9SAndroid Build Coastguard Worker 
4335*da0073e9SAndroid Build Coastguard Worker   /// size change
4336*da0073e9SAndroid Build Coastguard Worker   crossmaplrn2d =
4337*da0073e9SAndroid Build Coastguard Worker       CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1));
4338*da0073e9SAndroid Build Coastguard Worker   output = crossmaplrn2d(input);
4339*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
4340*da0073e9SAndroid Build Coastguard Worker       {{{{0.00000000, 0.99998120, 1.99985000},
4341*da0073e9SAndroid Build Coastguard Worker          {2.99949400, 3.99880050, 4.99765800},
4342*da0073e9SAndroid Build Coastguard Worker          {5.99595300, 6.99357600, 7.99041300}}}},
4343*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4344*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
4345*da0073e9SAndroid Build Coastguard Worker 
4346*da0073e9SAndroid Build Coastguard Worker   /// alpha change
4347*da0073e9SAndroid Build Coastguard Worker   crossmaplrn2d =
4348*da0073e9SAndroid Build Coastguard Worker       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1));
4349*da0073e9SAndroid Build Coastguard Worker   output = crossmaplrn2d(input);
4350*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
4351*da0073e9SAndroid Build Coastguard Worker       {{{{0.00000000, 0.99975010, 1.99800230},
4352*da0073e9SAndroid Build Coastguard Worker          {2.99326750, 3.98407440, 4.96897600},
4353*da0073e9SAndroid Build Coastguard Worker          {5.94656100, 6.91545720, 7.87434340}}}},
4354*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4355*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
4356*da0073e9SAndroid Build Coastguard Worker 
4357*da0073e9SAndroid Build Coastguard Worker   /// beta change
4358*da0073e9SAndroid Build Coastguard Worker   crossmaplrn2d =
4359*da0073e9SAndroid Build Coastguard Worker       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1));
4360*da0073e9SAndroid Build Coastguard Worker   output = crossmaplrn2d(input);
4361*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
4362*da0073e9SAndroid Build Coastguard Worker       {{{{0.00000000, 0.99996830, 1.99974680},
4363*da0073e9SAndroid Build Coastguard Worker          {2.99914500, 3.99797440, 4.99604460},
4364*da0073e9SAndroid Build Coastguard Worker          {5.99316840, 6.98915600, 7.98382000}}}},
4365*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4366*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
4367*da0073e9SAndroid Build Coastguard Worker 
4368*da0073e9SAndroid Build Coastguard Worker   /// k change
4369*da0073e9SAndroid Build Coastguard Worker   crossmaplrn2d =
4370*da0073e9SAndroid Build Coastguard Worker       CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2));
4371*da0073e9SAndroid Build Coastguard Worker   output = crossmaplrn2d(input);
4372*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor(
4373*da0073e9SAndroid Build Coastguard Worker       {{{{0.00000000, 0.59459610, 1.18914770},
4374*da0073e9SAndroid Build Coastguard Worker          {1.78361000, 2.37793870, 2.97208900},
4375*da0073e9SAndroid Build Coastguard Worker          {3.56601700, 4.15967700, 4.75302650}}}},
4376*da0073e9SAndroid Build Coastguard Worker       torch::kFloat32);
4377*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output.allclose(expected));
4378*da0073e9SAndroid Build Coastguard Worker }
4379*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,RNNCell)4380*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, RNNCell) {
4381*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
4382*da0073e9SAndroid Build Coastguard Worker   auto rnn = RNNCell(1, 2);
4383*da0073e9SAndroid Build Coastguard Worker 
4384*da0073e9SAndroid Build Coastguard Worker   auto input = torch::randn({3, 1});
4385*da0073e9SAndroid Build Coastguard Worker   auto hx = torch::randn({3, 2});
4386*da0073e9SAndroid Build Coastguard Worker   auto output = rnn(input, hx);
4387*da0073e9SAndroid Build Coastguard Worker   auto expected =
4388*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}});
4389*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4390*da0073e9SAndroid Build Coastguard Worker 
4391*da0073e9SAndroid Build Coastguard Worker   output = rnn(input);
4392*da0073e9SAndroid Build Coastguard Worker   expected =
4393*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}});
4394*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4395*da0073e9SAndroid Build Coastguard Worker 
4396*da0073e9SAndroid Build Coastguard Worker   input = torch::randn({1});
4397*da0073e9SAndroid Build Coastguard Worker   hx = torch::randn({2});
4398*da0073e9SAndroid Build Coastguard Worker   output = rnn(input, hx);
4399*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor({0.2808, 0.6505});
4400*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4401*da0073e9SAndroid Build Coastguard Worker 
4402*da0073e9SAndroid Build Coastguard Worker   {
4403*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 2});
4404*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4405*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4406*da0073e9SAndroid Build Coastguard Worker         rnn(input, hx), "input has inconsistent input_size: got 2 expected 1");
4407*da0073e9SAndroid Build Coastguard Worker   }
4408*da0073e9SAndroid Build Coastguard Worker 
4409*da0073e9SAndroid Build Coastguard Worker   {
4410*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4411*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1});
4412*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4413*da0073e9SAndroid Build Coastguard Worker         rnn(input, hx),
4414*da0073e9SAndroid Build Coastguard Worker         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4415*da0073e9SAndroid Build Coastguard Worker   }
4416*da0073e9SAndroid Build Coastguard Worker 
4417*da0073e9SAndroid Build Coastguard Worker   {
4418*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1, 1, 1, 1});
4419*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4420*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4421*da0073e9SAndroid Build Coastguard Worker         rnn(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4422*da0073e9SAndroid Build Coastguard Worker   }
4423*da0073e9SAndroid Build Coastguard Worker 
4424*da0073e9SAndroid Build Coastguard Worker   {
4425*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4426*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1, 1, 1, 2});
4427*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4428*da0073e9SAndroid Build Coastguard Worker         rnn(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4429*da0073e9SAndroid Build Coastguard Worker   }
4430*da0073e9SAndroid Build Coastguard Worker }
4431*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,LSTMCell)4432*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, LSTMCell) {
4433*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
4434*da0073e9SAndroid Build Coastguard Worker   auto lstm = LSTMCell(1, 2);
4435*da0073e9SAndroid Build Coastguard Worker 
4436*da0073e9SAndroid Build Coastguard Worker   auto input = torch::randn({3, 1});
4437*da0073e9SAndroid Build Coastguard Worker   auto hx = torch::randn({3, 2});
4438*da0073e9SAndroid Build Coastguard Worker   auto cx = torch::randn({3, 2});
4439*da0073e9SAndroid Build Coastguard Worker   auto output = lstm(input, std::make_tuple(hx, cx));
4440*da0073e9SAndroid Build Coastguard Worker   auto output_hx = std::get<0>(output);
4441*da0073e9SAndroid Build Coastguard Worker   auto output_cx = std::get<1>(output);
4442*da0073e9SAndroid Build Coastguard Worker   auto expected_hx =
4443*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}});
4444*da0073e9SAndroid Build Coastguard Worker   auto expected_cx =
4445*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}});
4446*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4447*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4448*da0073e9SAndroid Build Coastguard Worker 
4449*da0073e9SAndroid Build Coastguard Worker   output = lstm(input);
4450*da0073e9SAndroid Build Coastguard Worker   output_hx = std::get<0>(output);
4451*da0073e9SAndroid Build Coastguard Worker   output_cx = std::get<1>(output);
4452*da0073e9SAndroid Build Coastguard Worker   expected_hx =
4453*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}});
4454*da0073e9SAndroid Build Coastguard Worker   expected_cx =
4455*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}});
4456*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4457*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4458*da0073e9SAndroid Build Coastguard Worker 
4459*da0073e9SAndroid Build Coastguard Worker   input = torch::randn({1});
4460*da0073e9SAndroid Build Coastguard Worker   hx = torch::randn({2});
4461*da0073e9SAndroid Build Coastguard Worker   cx = torch::randn({2});
4462*da0073e9SAndroid Build Coastguard Worker   output = lstm(input, std::make_tuple(hx, cx));
4463*da0073e9SAndroid Build Coastguard Worker   output_hx = std::get<0>(output);
4464*da0073e9SAndroid Build Coastguard Worker   output_cx = std::get<1>(output);
4465*da0073e9SAndroid Build Coastguard Worker   expected_hx = torch::tensor({-0.0443, 0.1537});
4466*da0073e9SAndroid Build Coastguard Worker   expected_cx = torch::tensor({-0.1195, 0.2144});
4467*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4468*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4469*da0073e9SAndroid Build Coastguard Worker 
4470*da0073e9SAndroid Build Coastguard Worker   {
4471*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 2});
4472*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4473*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 2});
4474*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4475*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4476*da0073e9SAndroid Build Coastguard Worker         "input has inconsistent input_size: got 2 expected 1");
4477*da0073e9SAndroid Build Coastguard Worker   }
4478*da0073e9SAndroid Build Coastguard Worker 
4479*da0073e9SAndroid Build Coastguard Worker   {
4480*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4481*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1});
4482*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 2});
4483*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4484*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4485*da0073e9SAndroid Build Coastguard Worker         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4486*da0073e9SAndroid Build Coastguard Worker   }
4487*da0073e9SAndroid Build Coastguard Worker 
4488*da0073e9SAndroid Build Coastguard Worker   {
4489*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4490*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4491*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 1});
4492*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4493*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4494*da0073e9SAndroid Build Coastguard Worker         "hidden1 has inconsistent hidden_size: got 1, expected 2");
4495*da0073e9SAndroid Build Coastguard Worker   }
4496*da0073e9SAndroid Build Coastguard Worker 
4497*da0073e9SAndroid Build Coastguard Worker   {
4498*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1, 1, 1, 1});
4499*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1});
4500*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 1});
4501*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4502*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4503*da0073e9SAndroid Build Coastguard Worker         "Expected input to be 1D or 2D, got 5D instead");
4504*da0073e9SAndroid Build Coastguard Worker   }
4505*da0073e9SAndroid Build Coastguard Worker 
4506*da0073e9SAndroid Build Coastguard Worker   {
4507*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4508*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1, 1, 1, 2});
4509*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 2});
4510*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4511*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4512*da0073e9SAndroid Build Coastguard Worker         "Expected hx[0] to be 1D or 2D, got 5D instead");
4513*da0073e9SAndroid Build Coastguard Worker   }
4514*da0073e9SAndroid Build Coastguard Worker 
4515*da0073e9SAndroid Build Coastguard Worker   {
4516*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4517*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4518*da0073e9SAndroid Build Coastguard Worker     auto cx = torch::randn({3, 1, 1, 1, 2});
4519*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4520*da0073e9SAndroid Build Coastguard Worker         lstm(input, std::make_tuple(hx, cx)),
4521*da0073e9SAndroid Build Coastguard Worker         "Expected hx[1] to be 1D or 2D, got 5D instead");
4522*da0073e9SAndroid Build Coastguard Worker   }
4523*da0073e9SAndroid Build Coastguard Worker }
4524*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,GRUCell)4525*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, GRUCell) {
4526*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
4527*da0073e9SAndroid Build Coastguard Worker   auto gru = GRUCell(1, 2);
4528*da0073e9SAndroid Build Coastguard Worker 
4529*da0073e9SAndroid Build Coastguard Worker   auto input = torch::randn({3, 1});
4530*da0073e9SAndroid Build Coastguard Worker   auto hx = torch::randn({3, 2});
4531*da0073e9SAndroid Build Coastguard Worker   auto output = gru(input, hx);
4532*da0073e9SAndroid Build Coastguard Worker   auto expected =
4533*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}});
4534*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4535*da0073e9SAndroid Build Coastguard Worker 
4536*da0073e9SAndroid Build Coastguard Worker   output = gru(input);
4537*da0073e9SAndroid Build Coastguard Worker   expected =
4538*da0073e9SAndroid Build Coastguard Worker       torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}});
4539*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4540*da0073e9SAndroid Build Coastguard Worker 
4541*da0073e9SAndroid Build Coastguard Worker   input = torch::randn({1});
4542*da0073e9SAndroid Build Coastguard Worker   hx = torch::randn({2});
4543*da0073e9SAndroid Build Coastguard Worker   output = gru(input, hx);
4544*da0073e9SAndroid Build Coastguard Worker   expected = torch::tensor({-1.0058, -0.3025});
4545*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4546*da0073e9SAndroid Build Coastguard Worker 
4547*da0073e9SAndroid Build Coastguard Worker   {
4548*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 2});
4549*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4550*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4551*da0073e9SAndroid Build Coastguard Worker         gru(input, hx), "input has inconsistent input_size: got 2 expected 1");
4552*da0073e9SAndroid Build Coastguard Worker   }
4553*da0073e9SAndroid Build Coastguard Worker 
4554*da0073e9SAndroid Build Coastguard Worker   {
4555*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4556*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1});
4557*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4558*da0073e9SAndroid Build Coastguard Worker         gru(input, hx),
4559*da0073e9SAndroid Build Coastguard Worker         "hidden0 has inconsistent hidden_size: got 1, expected 2");
4560*da0073e9SAndroid Build Coastguard Worker   }
4561*da0073e9SAndroid Build Coastguard Worker 
4562*da0073e9SAndroid Build Coastguard Worker   {
4563*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1, 1, 1, 1});
4564*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 2});
4565*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4566*da0073e9SAndroid Build Coastguard Worker         gru(input, hx), "Expected input to be 1D or 2D, got 5D instead");
4567*da0073e9SAndroid Build Coastguard Worker   }
4568*da0073e9SAndroid Build Coastguard Worker 
4569*da0073e9SAndroid Build Coastguard Worker   {
4570*da0073e9SAndroid Build Coastguard Worker     auto input = torch::randn({3, 1});
4571*da0073e9SAndroid Build Coastguard Worker     auto hx = torch::randn({3, 1, 1, 1, 2});
4572*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
4573*da0073e9SAndroid Build Coastguard Worker         gru(input, hx), "Expected hidden to be 1D or 2D, got 5D instead");
4574*da0073e9SAndroid Build Coastguard Worker   }
4575*da0073e9SAndroid Build Coastguard Worker }
4576*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintLinear) {
  // With only the feature sizes given, the printed form shows bias=true.
  auto linear = Linear(3, 4);
  const std::string expected =
      "torch::nn::Linear(in_features=3, out_features=4, bias=true)";
  ASSERT_EQ(c10::str(linear), expected);
}
4582*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintBilinear) {
  // Positional construction: the printed form shows bias=true.
  auto with_bias = Bilinear(3, 2, 4);
  const std::string with_bias_text =
      "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)";
  ASSERT_EQ(c10::str(with_bias), with_bias_text);

  // Disabling bias through the options object is reflected in the output.
  auto without_bias = Bilinear(BilinearOptions(3, 2, 4).bias(false));
  const std::string without_bias_text =
      "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)";
  ASSERT_EQ(c10::str(without_bias), without_bias_text);
}
4591*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,PrettyPrintConv)4592*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, PrettyPrintConv) {
4593*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4594*da0073e9SAndroid Build Coastguard Worker       c10::str(Conv1d(3, 4, 5)),
4595*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
4596*da0073e9SAndroid Build Coastguard Worker 
4597*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4598*da0073e9SAndroid Build Coastguard Worker       c10::str(Conv2d(3, 4, 5)),
4599*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4600*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4601*da0073e9SAndroid Build Coastguard Worker       c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
4602*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4603*da0073e9SAndroid Build Coastguard Worker   {
4604*da0073e9SAndroid Build Coastguard Worker     const auto options =
4605*da0073e9SAndroid Build Coastguard Worker         Conv2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4606*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
4607*da0073e9SAndroid Build Coastguard Worker         c10::str(Conv2d(options)),
4608*da0073e9SAndroid Build Coastguard Worker         "torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4609*da0073e9SAndroid Build Coastguard Worker   }
4610*da0073e9SAndroid Build Coastguard Worker 
4611*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4612*da0073e9SAndroid Build Coastguard Worker       c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4613*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4614*da0073e9SAndroid Build Coastguard Worker   {
4615*da0073e9SAndroid Build Coastguard Worker     const auto options = Conv3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4616*da0073e9SAndroid Build Coastguard Worker                              .stride({1, 2, 3})
4617*da0073e9SAndroid Build Coastguard Worker                              .padding(1)
4618*da0073e9SAndroid Build Coastguard Worker                              .dilation(0)
4619*da0073e9SAndroid Build Coastguard Worker                              .groups(2)
4620*da0073e9SAndroid Build Coastguard Worker                              .bias(false)
4621*da0073e9SAndroid Build Coastguard Worker                              .padding_mode(torch::kCircular);
4622*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
4623*da0073e9SAndroid Build Coastguard Worker         c10::str(Conv3d(options)),
4624*da0073e9SAndroid Build Coastguard Worker         "torch::nn::Conv3d("
4625*da0073e9SAndroid Build Coastguard Worker         "4, "
4626*da0073e9SAndroid Build Coastguard Worker         "4, "
4627*da0073e9SAndroid Build Coastguard Worker         "kernel_size=[5, 6, 7], "
4628*da0073e9SAndroid Build Coastguard Worker         "stride=[1, 2, 3], "
4629*da0073e9SAndroid Build Coastguard Worker         "padding=[1, 1, 1], "
4630*da0073e9SAndroid Build Coastguard Worker         "dilation=[0, 0, 0], "
4631*da0073e9SAndroid Build Coastguard Worker         "groups=2, "
4632*da0073e9SAndroid Build Coastguard Worker         "bias=false, "
4633*da0073e9SAndroid Build Coastguard Worker         "padding_mode=kCircular)");
4634*da0073e9SAndroid Build Coastguard Worker   }
4635*da0073e9SAndroid Build Coastguard Worker }
4636*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,PrettyPrintConvTranspose)4637*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, PrettyPrintConvTranspose) {
4638*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4639*da0073e9SAndroid Build Coastguard Worker       c10::str(ConvTranspose1d(3, 4, 5)),
4640*da0073e9SAndroid Build Coastguard Worker       "torch::nn::ConvTranspose1d(3, 4, kernel_size=5, stride=1)");
4641*da0073e9SAndroid Build Coastguard Worker 
4642*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4643*da0073e9SAndroid Build Coastguard Worker       c10::str(ConvTranspose2d(3, 4, 5)),
4644*da0073e9SAndroid Build Coastguard Worker       "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4645*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4646*da0073e9SAndroid Build Coastguard Worker       c10::str(ConvTranspose2d(ConvTranspose2dOptions(3, 4, 5).stride(2))),
4647*da0073e9SAndroid Build Coastguard Worker       "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4648*da0073e9SAndroid Build Coastguard Worker   {
4649*da0073e9SAndroid Build Coastguard Worker     const auto options =
4650*da0073e9SAndroid Build Coastguard Worker         ConvTranspose2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4651*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
4652*da0073e9SAndroid Build Coastguard Worker         c10::str(ConvTranspose2d(options)),
4653*da0073e9SAndroid Build Coastguard Worker         "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4654*da0073e9SAndroid Build Coastguard Worker   }
4655*da0073e9SAndroid Build Coastguard Worker 
4656*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4657*da0073e9SAndroid Build Coastguard Worker       c10::str(ConvTranspose3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4658*da0073e9SAndroid Build Coastguard Worker       "torch::nn::ConvTranspose3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4659*da0073e9SAndroid Build Coastguard Worker   {
4660*da0073e9SAndroid Build Coastguard Worker     const auto options =
4661*da0073e9SAndroid Build Coastguard Worker         ConvTranspose3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4662*da0073e9SAndroid Build Coastguard Worker             .stride({1, 2, 3})
4663*da0073e9SAndroid Build Coastguard Worker             .padding(1)
4664*da0073e9SAndroid Build Coastguard Worker             .dilation(0)
4665*da0073e9SAndroid Build Coastguard Worker             .groups(2)
4666*da0073e9SAndroid Build Coastguard Worker             .bias(false)
4667*da0073e9SAndroid Build Coastguard Worker             .padding_mode(torch::kCircular);
4668*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
4669*da0073e9SAndroid Build Coastguard Worker         c10::str(ConvTranspose3d(options)),
4670*da0073e9SAndroid Build Coastguard Worker         "torch::nn::ConvTranspose3d("
4671*da0073e9SAndroid Build Coastguard Worker         "4, "
4672*da0073e9SAndroid Build Coastguard Worker         "4, "
4673*da0073e9SAndroid Build Coastguard Worker         "kernel_size=[5, 6, 7], "
4674*da0073e9SAndroid Build Coastguard Worker         "stride=[1, 2, 3], "
4675*da0073e9SAndroid Build Coastguard Worker         "padding=[1, 1, 1], "
4676*da0073e9SAndroid Build Coastguard Worker         "dilation=[0, 0, 0], "
4677*da0073e9SAndroid Build Coastguard Worker         "groups=2, "
4678*da0073e9SAndroid Build Coastguard Worker         "bias=false, "
4679*da0073e9SAndroid Build Coastguard Worker         "padding_mode=kCircular)");
4680*da0073e9SAndroid Build Coastguard Worker   }
4681*da0073e9SAndroid Build Coastguard Worker }
4682*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintUpsample) {
  // Configured by target size; the printed mode is kNearest.
  auto by_size =
      Upsample(UpsampleOptions().size(std::vector<int64_t>({2, 4, 4})));
  ASSERT_EQ(
      c10::str(by_size), "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)");

  // Configured by per-dimension scale factors with an explicit mode.
  auto by_scale = Upsample(UpsampleOptions()
                               .scale_factor(std::vector<double>({0.5, 1.5}))
                               .mode(torch::kBilinear));
  ASSERT_EQ(
      c10::str(by_scale),
      "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)");
}
4694*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintFold) {
  // Only output_size and kernel_size given. The printed form shows
  // dilation=[1, 1], padding=[0, 0] and stride=[1, 1].
  auto plain = Fold(FoldOptions({2, 2}, {5, 5}));
  ASSERT_EQ(
      c10::str(plain),
      "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");

  // Scalar dilation/stride expand per dimension; padding is given per dim.
  auto custom =
      Fold(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2));
  ASSERT_EQ(
      c10::str(custom),
      "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
}
4704*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintUnfold) {
  // Built directly from a kernel-size array ref. The printed form shows
  // dilation=[1, 1], padding=[0, 0] and stride=[1, 1].
  auto plain = Unfold(torch::IntArrayRef({2, 4}));
  ASSERT_EQ(
      c10::str(plain),
      "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");

  // Scalar dilation/stride expand per dimension; padding is given per dim.
  auto custom =
      Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
  ASSERT_EQ(
      c10::str(custom),
      "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
}
4714*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintMaxPool) {
  // 1-D with only the kernel given: stride prints equal to kernel_size.
  auto pool1 = MaxPool1d(5);
  ASSERT_EQ(
      c10::str(pool1),
      "torch::nn::MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=false)");

  // 2-D: scalar settings print as per-dimension lists.
  auto pool2 = MaxPool2d(5);
  ASSERT_EQ(
      c10::str(pool2),
      "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");

  auto pool2_strided = MaxPool2d(MaxPool2dOptions(5).stride(2));
  ASSERT_EQ(
      c10::str(pool2_strided),
      "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");

  // 3-D variants of the same two configurations.
  auto pool3 = MaxPool3d(5);
  ASSERT_EQ(
      c10::str(pool3),
      "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");

  auto pool3_strided = MaxPool3d(MaxPool3dOptions(5).stride(2));
  ASSERT_EQ(
      c10::str(pool3_strided),
      "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");

  // 2-D with a rectangular kernel and a different stride per dimension.
  auto pool2_rect =
      MaxPool2d(MaxPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2}));
  ASSERT_EQ(
      c10::str(pool2_rect),
      "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
}
4738*da0073e9SAndroid Build Coastguard Worker 
// Checks the repr of AvgPool{1,2,3}d. When no stride is given, the printed
// stride equals kernel_size. Scalar arguments print once per spatial dimension.
TEST_F(ModulesTest, PrettyPrintAvgPool) {
  // Non-square kernel with a per-dimension stride.
  const auto options =
      AvgPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});

  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(AvgPool1d(5)),
       "torch::nn::AvgPool1d(kernel_size=5, stride=5, padding=0)"},
      {c10::str(AvgPool2d(5)),
       "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])"},
      {c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))),
       "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0])"},
      {c10::str(AvgPool3d(5)),
       "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0])"},
      {c10::str(AvgPool3d(AvgPool3dOptions(5).stride(2))),
       "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0])"},
      {c10::str(AvgPool2d(options)),
       "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0])"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of FractionalMaxPool{2,3}d. These modules print no options,
// so the expected output is the module name followed by empty parentheses.
// (Previously registered under the misspelled name
// "PrettyPrinFractionalMaxPool"; renamed to match the PrettyPrint* tests.)
TEST_F(ModulesTest, PrettyPrintFractionalMaxPool) {
  ASSERT_EQ(
      c10::str(
          FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))),
      "torch::nn::FractionalMaxPool2d()");
  ASSERT_EQ(
      c10::str(
          FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))),
      "torch::nn::FractionalMaxPool3d()");
}

// Checks the repr of LPPool{1,2,3}d with default and explicit stride/ceil_mode.
// When no stride is given, the printed stride equals kernel_size.
TEST_F(ModulesTest, PrettyPrintLPPool) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(LPPool1d(2, 5)),
       "torch::nn::LPPool1d(norm_type=2, kernel_size=5, stride=5, ceil_mode=false)"},
      {c10::str(LPPool1d(LPPool1dOptions(1, 2).stride(5).ceil_mode(true))),
       "torch::nn::LPPool1d(norm_type=1, kernel_size=2, stride=5, ceil_mode=true)"},
      {c10::str(LPPool2d(2, std::vector<int64_t>({1, 2}))),
       "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)"},
      {c10::str(LPPool2d(LPPool2dOptions(1, std::vector<int64_t>({3, 4}))
                             .stride({5, 6})
                             .ceil_mode(true))),
       "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)"},
      {c10::str(LPPool3d(2, std::vector<int64_t>({1, 2, 3}))),
       "torch::nn::LPPool3d(norm_type=2, kernel_size=[1, 2, 3], stride=[1, 2, 3], ceil_mode=false)"},
      {c10::str(LPPool3d(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5}))
                             .stride({5, 6, 7})
                             .ceil_mode(true))),
       "torch::nn::LPPool3d(norm_type=1, kernel_size=[3, 4, 5], stride=[5, 6, 7], ceil_mode=true)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

TEST_F(ModulesTest,PrettyPrintAdaptiveMaxPool)4799*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
4800*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4801*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool1d(5)),
4802*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool1d(output_size=5)");
4803*da0073e9SAndroid Build Coastguard Worker 
4804*da0073e9SAndroid Build Coastguard Worker   const auto options = AdaptiveMaxPool1dOptions(3);
4805*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4806*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool1d(options)),
4807*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool1d(output_size=3)");
4808*da0073e9SAndroid Build Coastguard Worker 
4809*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4810*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool2d(5)),
4811*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool2d(output_size=[5, 5])");
4812*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4813*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, 6}))),
4814*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool2d(output_size=[5, 6])");
4815*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4816*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, std::nullopt}))),
4817*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])");
4818*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4819*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool2d(
4820*da0073e9SAndroid Build Coastguard Worker           AdaptiveMaxPool2dOptions({std::nullopt, std::nullopt}))),
4821*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])");
4822*da0073e9SAndroid Build Coastguard Worker 
4823*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4824*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool3d(5)),
4825*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool3d(output_size=[5, 5, 5])");
4826*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4827*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))),
4828*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])");
4829*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4830*da0073e9SAndroid Build Coastguard Worker       c10::str(
4831*da0073e9SAndroid Build Coastguard Worker           AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, std::nullopt, 7}))),
4832*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])");
4833*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
4834*da0073e9SAndroid Build Coastguard Worker       c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions(
4835*da0073e9SAndroid Build Coastguard Worker           {std::nullopt, std::nullopt, std::nullopt}))),
4836*da0073e9SAndroid Build Coastguard Worker       "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])");
4837*da0073e9SAndroid Build Coastguard Worker }
4838*da0073e9SAndroid Build Coastguard Worker 
// Checks the repr of AdaptiveAvgPool{1,2,3}d. A scalar output_size prints once
// per dimension, and std::nullopt entries print as "None".
TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(AdaptiveAvgPool1d(5)),
       "torch::nn::AdaptiveAvgPool1d(output_size=5)"},
      {c10::str(AdaptiveAvgPool2d(5)),
       "torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])"},
      {c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, 6}))),
       "torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])"},
      {c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, std::nullopt}))),
       "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])"},
      {c10::str(AdaptiveAvgPool2d(
           AdaptiveAvgPool2dOptions({std::nullopt, std::nullopt}))),
       "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])"},
      {c10::str(AdaptiveAvgPool3d(5)),
       "torch::nn::AdaptiveAvgPool3d(output_size=[5, 5, 5])"},
      {c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))),
       "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])"},
      {c10::str(
           AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, std::nullopt, 7}))),
       "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])"},
      {c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions(
           {std::nullopt, std::nullopt, std::nullopt}))),
       "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of MaxUnpool{1,2}d with default and explicit stride/padding.
// When no stride is given, the printed stride equals kernel_size.
TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(MaxUnpool1d(5)),
       "torch::nn::MaxUnpool1d(kernel_size=5, stride=5, padding=0)"},
      {c10::str(MaxUnpool1d(MaxUnpool1dOptions(5).stride(3).padding(1))),
       "torch::nn::MaxUnpool1d(kernel_size=5, stride=3, padding=1)"},
      {c10::str(MaxUnpool2d(5)),
       "torch::nn::MaxUnpool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])"},
      {c10::str(MaxUnpool2d(std::vector<int64_t>{5, 6})),
       "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])"},
      {c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector<int64_t>{5, 6})
                                .stride({3, 4})
                                .padding({1, 2}))),
       "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of Dropout: the default (p=0.5), a custom p, and inplace=true.
TEST_F(ModulesTest, PrettyPrintDropout) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)"},
      {c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)"},
      {c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))),
       "torch::nn::Dropout(p=0.42, inplace=true)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of Dropout2d: the default (p=0.5), a custom p, and
// inplace=true.
TEST_F(ModulesTest, PrettyPrintDropout2d) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)"},
      {c10::str(Dropout2d(0.42)),
       "torch::nn::Dropout2d(p=0.42, inplace=false)"},
      {c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))),
       "torch::nn::Dropout2d(p=0.42, inplace=true)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of Dropout3d: the default (p=0.5), a custom p, and
// inplace=true.
TEST_F(ModulesTest, PrettyPrintDropout3d) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)"},
      {c10::str(Dropout3d(0.42)),
       "torch::nn::Dropout3d(p=0.42, inplace=false)"},
      {c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))),
       "torch::nn::Dropout3d(p=0.42, inplace=true)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks that a Functional wrapping torch::relu prints with empty parentheses.
TEST_F(ModulesTest, PrettyPrintFunctional) {
  const auto relu_module = Functional(torch::relu);
  const std::string expected = "torch::nn::Functional()";
  ASSERT_EQ(c10::str(relu_module), expected);
}

// Checks the repr of BatchNorm1d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
  const auto options = BatchNorm1dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(BatchNorm1d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of BatchNorm2d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintBatchNorm2d) {
  const auto options = BatchNorm2dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(BatchNorm2d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of BatchNorm3d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
  const auto options = BatchNorm3dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(BatchNorm3d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of InstanceNorm1d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
  const auto options = InstanceNorm1dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(InstanceNorm1d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of InstanceNorm2d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
  const auto options = InstanceNorm2dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(InstanceNorm2d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of InstanceNorm3d when eps, momentum, affine and
// track_running_stats are all set explicitly.
TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
  const auto options = InstanceNorm3dOptions(4)
                           .eps(0.5)
                           .momentum(0.1)
                           .affine(false)
                           .track_running_stats(true);
  const std::string repr = c10::str(InstanceNorm3d(options));
  ASSERT_EQ(
      repr,
      "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}

// Checks the repr of LayerNorm with default eps/elementwise_affine and with
// both overridden.
TEST_F(ModulesTest, PrettyPrintLayerNorm) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(LayerNorm(LayerNormOptions({2, 2}))),
       "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)"},
      {c10::str(LayerNorm(
           LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))),
       "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of GroupNorm with default eps/affine and with both
// overridden.
TEST_F(ModulesTest, PrettyPrintGroupNorm) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(GroupNorm(GroupNormOptions(2, 2))),
       "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)"},
      {c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))),
       "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of LocalResponseNorm with default alpha/beta/k and with all
// three overridden.
TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
  const std::vector<std::pair<std::string, std::string>> cases = {
      {c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
       "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)"},
      {c10::str(LocalResponseNorm(
           LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
       "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)"},
  };
  for (const auto& [actual, expected] : cases) {
    ASSERT_EQ(actual, expected);
  }
}

// Checks the repr of Embedding as options are layered on one at a time:
// bare sizes, then padding_idx/max_norm, then norm_type/scale_grad_by_freq/
// sparse. Each newly set option shows up in the printed output.
TEST_F(ModulesTest, PrettyPrintEmbedding) {
  // Modules are built from copies of `opts`, so later mutation does not
  // affect earlier assertions.
  auto opts = EmbeddingOptions(10, 2);
  ASSERT_EQ(
      c10::str(Embedding(opts)),
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");

  opts.padding_idx(3).max_norm(2);
  ASSERT_EQ(
      c10::str(Embedding(opts)),
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");

  opts.norm_type(2.5).scale_grad_by_freq(true).sparse(true);
  ASSERT_EQ(
      c10::str(Embedding(opts)),
      "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
}

// Checks the repr of EmbeddingBag as options are layered on one step at a
// time: bare sizes, max_norm, norm_type/scale_grad_by_freq/sparse, mode=kSum,
// and finally padding_idx. Each newly set option is appended to the output.
TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
  // Modules are built from copies of `opts`, so later mutation does not
  // affect earlier assertions.
  auto opts = EmbeddingBagOptions(10, 2);
  ASSERT_EQ(
      c10::str(EmbeddingBag(opts)),
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)");

  opts.max_norm(2);
  ASSERT_EQ(
      c10::str(EmbeddingBag(opts)),
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)");

  opts.norm_type(2.5).scale_grad_by_freq(true).sparse(true);
  ASSERT_EQ(
      c10::str(EmbeddingBag(opts)),
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");

  opts.mode(torch::kSum);
  ASSERT_EQ(
      c10::str(EmbeddingBag(opts)),
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)");

  opts.padding_idx(5);
  ASSERT_EQ(
      c10::str(EmbeddingBag(opts)),
      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)");
}

// Checks that a default-constructed L1Loss prints with empty parentheses.
TEST_F(ModulesTest, PrettyPrintL1Loss) {
  const auto criterion = L1Loss();
  ASSERT_EQ(c10::str(criterion), "torch::nn::L1Loss()");
}

// Checks that a default-constructed KLDivLoss prints with empty parentheses.
TEST_F(ModulesTest, PrettyPrintKLDivLoss) {
  const auto criterion = KLDivLoss();
  ASSERT_EQ(c10::str(criterion), "torch::nn::KLDivLoss()");
}

// Checks that a default-constructed MSELoss prints with empty parentheses.
TEST_F(ModulesTest, PrettyPrintMSELoss) {
  const auto criterion = MSELoss();
  ASSERT_EQ(c10::str(criterion), "torch::nn::MSELoss()");
}

// Checks that a default-constructed BCELoss prints with empty parentheses.
TEST_F(ModulesTest, PrettyPrintBCELoss) {
  const auto criterion = BCELoss();
  ASSERT_EQ(c10::str(criterion), "torch::nn::BCELoss()");
}

// Checks that HingeEmbeddingLoss prints its configured margin.
TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) {
  const auto options = HingeEmbeddingLossOptions().margin(4);
  const std::string repr = c10::str(HingeEmbeddingLoss(options));
  ASSERT_EQ(repr, "torch::nn::HingeEmbeddingLoss(margin=4)");
}

TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) {
  // A fractional margin is printed in its shortest decimal form.
  const auto options = CosineEmbeddingLossOptions().margin(0.25);
  const auto printed = c10::str(CosineEmbeddingLoss(options));
  ASSERT_EQ(printed, "torch::nn::CosineEmbeddingLoss(margin=0.25)");
}
5090*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
  // Every option set here (margin, p, eps, swap) is expected in the output;
  // eps is rendered in scientific notation.
  const auto options =
      TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false);
  const auto printed = c10::str(TripletMarginLoss(options));
  ASSERT_EQ(
      printed,
      "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
}
5097*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
  // The distance function is handed to an options object that outlives this
  // expression (it is used again below), so it is written as a capture-free
  // lambda. It needs no outer state, and a default by-reference capture on a
  // stored callable only invites dangling references if it is ever extended.
  auto distanceOptions = TripletMarginWithDistanceLossOptions()
                             .distance_function([](const torch::Tensor& x,
                                                   const torch::Tensor& y) {
                               return torch::pairwise_distance(x, y, 2.0, 1e-6);
                             })
                             .margin(1.5)
                             .swap(true)
                             .reduction(torch::kMean);
  // Only margin and swap appear in the expected string; the distance function
  // and the reduction are not part of the printed form.
  ASSERT_EQ(
      c10::str(TripletMarginWithDistanceLoss(distanceOptions)),
      "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
}
5111*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintNLLLoss) {
  // A default-constructed NLLLoss prints with an empty argument list.
  const auto printed = c10::str(NLLLoss());
  ASSERT_EQ(printed, "torch::nn::NLLLoss()");
}
5115*da0073e9SAndroid Build Coastguard Worker 
// NOTE: the test name is missing the 't' of "Print". It is kept as-is so any
// gtest filters or skip lists that reference it keep matching.
TEST_F(ModulesTest, PrettyPrinCrossEntropyLoss) {
  // A default-constructed CrossEntropyLoss prints with an empty argument list.
  const auto printed = c10::str(CrossEntropyLoss());
  ASSERT_EQ(printed, "torch::nn::CrossEntropyLoss()");
}
5119*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) {
  // A default-constructed MultiLabelMarginLoss prints no arguments.
  const auto printed = c10::str(MultiLabelMarginLoss());
  ASSERT_EQ(printed, "torch::nn::MultiLabelMarginLoss()");
}
5124*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
  // A default-constructed MultiLabelSoftMarginLoss prints no arguments.
  const auto printed = c10::str(MultiLabelSoftMarginLoss());
  ASSERT_EQ(printed, "torch::nn::MultiLabelSoftMarginLoss()");
}
5130*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) {
  // A default-constructed SoftMarginLoss prints with an empty argument list.
  const auto printed = c10::str(SoftMarginLoss());
  ASSERT_EQ(printed, "torch::nn::SoftMarginLoss()");
}
5134*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintCosineSimilarity) {
  // Default options print as dim=1, eps=1e-08.
  const auto default_repr = c10::str(CosineSimilarity());
  ASSERT_EQ(default_repr, "torch::nn::CosineSimilarity(dim=1, eps=1e-08)");

  // Overridden dim and eps are reflected in the printed form.
  const auto custom_options = CosineSimilarityOptions().dim(0).eps(0.5);
  const auto custom_repr = c10::str(CosineSimilarity(custom_options));
  ASSERT_EQ(custom_repr, "torch::nn::CosineSimilarity(dim=0, eps=0.5)");
}
5143*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintPairwiseDistance) {
  // Default options print as p=2, eps=1e-06, keepdim=false.
  const auto default_repr = c10::str(PairwiseDistance());
  ASSERT_EQ(
      default_repr,
      "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)");

  // Every overridden option is reflected in the printed form.
  const auto custom_options =
      PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true);
  const auto custom_repr = c10::str(PairwiseDistance(custom_options));
  ASSERT_EQ(
      custom_repr, "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)");
}
5153*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintReflectionPad) {
  // A scalar padding expands to two entries (left/right, top/bottom, ...) per
  // padded dimension; an explicit list is printed verbatim. The 1d, 2d and 3d
  // variants are all covered, matching the Replication/Zero/Constant tests.
  ASSERT_EQ(
      c10::str(ReflectionPad1d(ReflectionPad1dOptions(2))),
      "torch::nn::ReflectionPad1d(padding=[2, 2])");
  ASSERT_EQ(
      c10::str(ReflectionPad1d(ReflectionPad1dOptions({3, 1}))),
      "torch::nn::ReflectionPad1d(padding=[3, 1])");
  ASSERT_EQ(
      c10::str(ReflectionPad2d(ReflectionPad2dOptions(2))),
      "torch::nn::ReflectionPad2d(padding=[2, 2, 2, 2])");
  ASSERT_EQ(
      c10::str(ReflectionPad2d(ReflectionPad2dOptions({1, 1, 2, 0}))),
      "torch::nn::ReflectionPad2d(padding=[1, 1, 2, 0])");
  ASSERT_EQ(
      c10::str(ReflectionPad3d(ReflectionPad3dOptions(1))),
      "torch::nn::ReflectionPad3d(padding=[1, 1, 1, 1, 1, 1])");
  ASSERT_EQ(
      c10::str(ReflectionPad3d(ReflectionPad3dOptions({1, 2, 1, 2, 1, 2}))),
      "torch::nn::ReflectionPad3d(padding=[1, 2, 1, 2, 1, 2])");
}
5168*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintReplicationPad) {
  // For each dimensionality, a scalar padding is expanded to two entries per
  // padded dimension, and an explicit list is printed verbatim.

  // 1d
  const auto pad1d_scalar =
      c10::str(ReplicationPad1d(ReplicationPad1dOptions(2)));
  ASSERT_EQ(pad1d_scalar, "torch::nn::ReplicationPad1d(padding=[2, 2])");
  const auto pad1d_list =
      c10::str(ReplicationPad1d(ReplicationPad1dOptions({3, 1})));
  ASSERT_EQ(pad1d_list, "torch::nn::ReplicationPad1d(padding=[3, 1])");

  // 2d
  const auto pad2d_scalar =
      c10::str(ReplicationPad2d(ReplicationPad2dOptions(2)));
  ASSERT_EQ(
      pad2d_scalar, "torch::nn::ReplicationPad2d(padding=[2, 2, 2, 2])");
  const auto pad2d_list =
      c10::str(ReplicationPad2d(ReplicationPad2dOptions({1, 1, 2, 0})));
  ASSERT_EQ(pad2d_list, "torch::nn::ReplicationPad2d(padding=[1, 1, 2, 0])");

  // 3d
  const auto pad3d_scalar =
      c10::str(ReplicationPad3d(ReplicationPad3dOptions(1)));
  ASSERT_EQ(
      pad3d_scalar,
      "torch::nn::ReplicationPad3d(padding=[1, 1, 1, 1, 1, 1])");
  const auto pad3d_list = c10::str(
      ReplicationPad3d(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})));
  ASSERT_EQ(
      pad3d_list, "torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
}
5189*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintZeroPad) {
  // For each dimensionality, a scalar padding is expanded to two entries per
  // padded dimension, and an explicit list is printed verbatim.

  // 1d
  const auto pad1d_scalar = c10::str(ZeroPad1d(ZeroPad1dOptions(2)));
  ASSERT_EQ(pad1d_scalar, "torch::nn::ZeroPad1d(padding=[2, 2])");
  const auto pad1d_list = c10::str(ZeroPad1d(ZeroPad1dOptions({3, 1})));
  ASSERT_EQ(pad1d_list, "torch::nn::ZeroPad1d(padding=[3, 1])");

  // 2d
  const auto pad2d_scalar = c10::str(ZeroPad2d(ZeroPad2dOptions(2)));
  ASSERT_EQ(pad2d_scalar, "torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
  const auto pad2d_list =
      c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0})));
  ASSERT_EQ(pad2d_list, "torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");

  // 3d
  const auto pad3d_scalar = c10::str(ZeroPad3d(ZeroPad3dOptions(1)));
  ASSERT_EQ(
      pad3d_scalar, "torch::nn::ZeroPad3d(padding=[1, 1, 1, 1, 1, 1])");
  const auto pad3d_list =
      c10::str(ZeroPad3d(ZeroPad3dOptions({1, 2, 1, 2, 1, 2})));
  ASSERT_EQ(pad3d_list, "torch::nn::ZeroPad3d(padding=[1, 2, 1, 2, 1, 2])");
}
5210*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintConstantPad) {
  // Same padding-expansion rules as the other pad modules, plus the fill
  // value, which every case here sets to 3.5.

  // 1d
  const auto pad1d_scalar =
      c10::str(ConstantPad1d(ConstantPad1dOptions(2, 3.5)));
  ASSERT_EQ(
      pad1d_scalar, "torch::nn::ConstantPad1d(padding=[2, 2], value=3.5)");
  const auto pad1d_list =
      c10::str(ConstantPad1d(ConstantPad1dOptions({3, 1}, 3.5)));
  ASSERT_EQ(
      pad1d_list, "torch::nn::ConstantPad1d(padding=[3, 1], value=3.5)");

  // 2d
  const auto pad2d_scalar =
      c10::str(ConstantPad2d(ConstantPad2dOptions(2, 3.5)));
  ASSERT_EQ(
      pad2d_scalar,
      "torch::nn::ConstantPad2d(padding=[2, 2, 2, 2], value=3.5)");
  const auto pad2d_list =
      c10::str(ConstantPad2d(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)));
  ASSERT_EQ(
      pad2d_list,
      "torch::nn::ConstantPad2d(padding=[3, 0, 2, 1], value=3.5)");

  // 3d
  const auto pad3d_scalar =
      c10::str(ConstantPad3d(ConstantPad3dOptions(1, 3.5)));
  ASSERT_EQ(
      pad3d_scalar,
      "torch::nn::ConstantPad3d(padding=[1, 1, 1, 1, 1, 1], value=3.5)");
  const auto pad3d_list = c10::str(
      ConstantPad3d(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)));
  ASSERT_EQ(
      pad3d_list,
      "torch::nn::ConstantPad3d(padding=[1, 2, 1, 2, 1, 2], value=3.5)");
}
5231*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,PrettyPrintNestedModel)5232*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, PrettyPrintNestedModel) {
5233*da0073e9SAndroid Build Coastguard Worker   struct InnerTestModule : torch::nn::Module {
5234*da0073e9SAndroid Build Coastguard Worker     InnerTestModule()
5235*da0073e9SAndroid Build Coastguard Worker         : torch::nn::Module("InnerTestModule"),
5236*da0073e9SAndroid Build Coastguard Worker           fc(register_module("fc", torch::nn::Linear(3, 4))),
5237*da0073e9SAndroid Build Coastguard Worker           table(register_module("table", torch::nn::Embedding(10, 2))) {}
5238*da0073e9SAndroid Build Coastguard Worker 
5239*da0073e9SAndroid Build Coastguard Worker     torch::nn::Linear fc;
5240*da0073e9SAndroid Build Coastguard Worker     torch::nn::Embedding table;
5241*da0073e9SAndroid Build Coastguard Worker   };
5242*da0073e9SAndroid Build Coastguard Worker 
5243*da0073e9SAndroid Build Coastguard Worker   struct TestModule : torch::nn::Module {
5244*da0073e9SAndroid Build Coastguard Worker     TestModule()
5245*da0073e9SAndroid Build Coastguard Worker         : torch::nn::Module("TestModule"),
5246*da0073e9SAndroid Build Coastguard Worker           fc(register_module("fc", torch::nn::Linear(4, 5))),
5247*da0073e9SAndroid Build Coastguard Worker           table(register_module(
5248*da0073e9SAndroid Build Coastguard Worker               "table",
5249*da0073e9SAndroid Build Coastguard Worker               torch::nn::Embedding(EmbeddingOptions(10, 2)))),
5250*da0073e9SAndroid Build Coastguard Worker           inner(register_module("inner", std::make_shared<InnerTestModule>())) {
5251*da0073e9SAndroid Build Coastguard Worker     }
5252*da0073e9SAndroid Build Coastguard Worker 
5253*da0073e9SAndroid Build Coastguard Worker     torch::nn::Linear fc;
5254*da0073e9SAndroid Build Coastguard Worker     torch::nn::Embedding table;
5255*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<InnerTestModule> inner;
5256*da0073e9SAndroid Build Coastguard Worker   };
5257*da0073e9SAndroid Build Coastguard Worker 
5258*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
5259*da0073e9SAndroid Build Coastguard Worker       c10::str(TestModule{}),
5260*da0073e9SAndroid Build Coastguard Worker       "TestModule(\n"
5261*da0073e9SAndroid Build Coastguard Worker       "  (fc): torch::nn::Linear(in_features=4, out_features=5, bias=true)\n"
5262*da0073e9SAndroid Build Coastguard Worker       "  (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5263*da0073e9SAndroid Build Coastguard Worker       "  (inner): InnerTestModule(\n"
5264*da0073e9SAndroid Build Coastguard Worker       "    (fc): torch::nn::Linear(in_features=3, out_features=4, bias=true)\n"
5265*da0073e9SAndroid Build Coastguard Worker       "    (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5266*da0073e9SAndroid Build Coastguard Worker       "  )\n"
5267*da0073e9SAndroid Build Coastguard Worker       ")");
5268*da0073e9SAndroid Build Coastguard Worker }
5269*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintELU) {
  // Default: only alpha is printed; inplace is omitted when false.
  const auto default_repr = c10::str(ELU());
  ASSERT_EQ(default_repr, "torch::nn::ELU(alpha=1)");

  // With inplace enabled, it is appended after alpha.
  const auto options = ELUOptions().alpha(42.42).inplace(true);
  const auto custom_repr = c10::str(ELU(options));
  ASSERT_EQ(custom_repr, "torch::nn::ELU(alpha=42.42, inplace=true)");
}
5276*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSELU) {
  // Default SELU prints no arguments.
  const auto default_repr = c10::str(SELU());
  ASSERT_EQ(default_repr, "torch::nn::SELU()");

  // inplace appears only when it is set.
  const auto options = SELUOptions().inplace(true);
  const auto custom_repr = c10::str(SELU(options));
  ASSERT_EQ(custom_repr, "torch::nn::SELU(inplace=true)");
}
5283*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintGLU) {
  // The default split dimension prints as -1.
  const auto default_repr = c10::str(GLU());
  ASSERT_EQ(default_repr, "torch::nn::GLU(dim=-1)");

  // The dimension may also be passed directly to the constructor.
  const auto custom_repr = c10::str(GLU(1));
  ASSERT_EQ(custom_repr, "torch::nn::GLU(dim=1)");
}
5288*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintHardshrink) {
  // The lambda value is printed bare, without a "lambda=" key.
  const auto default_repr = c10::str(Hardshrink());
  ASSERT_EQ(default_repr, "torch::nn::Hardshrink(0.5)");

  const auto options = HardshrinkOptions().lambda(42.42);
  const auto custom_repr = c10::str(Hardshrink(options));
  ASSERT_EQ(custom_repr, "torch::nn::Hardshrink(42.42)");
}
5295*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintHardtanh) {
  // Default clamp range prints as [-1, 1].
  const auto default_repr = c10::str(Hardtanh());
  ASSERT_EQ(default_repr, "torch::nn::Hardtanh(min_val=-1, max_val=1)");

  // Custom bounds plus inplace are all reflected.
  const auto options =
      HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true);
  const auto custom_repr = c10::str(Hardtanh(options));
  ASSERT_EQ(
      custom_repr,
      "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)");
}
5303*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
  // Default negative slope prints as 0.01.
  const auto default_repr = c10::str(LeakyReLU());
  ASSERT_EQ(default_repr, "torch::nn::LeakyReLU(negative_slope=0.01)");

  // Custom slope plus inplace are both reflected.
  const auto options = LeakyReLUOptions().negative_slope(0.42).inplace(true);
  const auto custom_repr = c10::str(LeakyReLU(options));
  ASSERT_EQ(
      custom_repr, "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
}
5311*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
  // LogSigmoid has nothing to print beyond its name.
  const auto printed = c10::str(LogSigmoid());
  ASSERT_EQ(printed, "torch::nn::LogSigmoid()");
}
5315*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftmax) {
  // The reduction dimension is the only printed option.
  const auto printed = c10::str(Softmax(SoftmaxOptions(1)));
  ASSERT_EQ(printed, "torch::nn::Softmax(dim=1)");
}
5319*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftmin) {
  // The reduction dimension is the only printed option.
  const auto printed = c10::str(Softmin(SoftminOptions(1)));
  ASSERT_EQ(printed, "torch::nn::Softmin(dim=1)");
}
5323*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintLogSoftmax) {
  // The reduction dimension is the only printed option.
  const auto printed = c10::str(LogSoftmax(LogSoftmaxOptions(1)));
  ASSERT_EQ(printed, "torch::nn::LogSoftmax(dim=1)");
}
5329*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftmax2d) {
  // Softmax2d takes no options, so nothing follows its name.
  const auto printed = c10::str(Softmax2d());
  ASSERT_EQ(printed, "torch::nn::Softmax2d()");
}
5333*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintPReLU) {
  // By default a single learnable parameter is used.
  const auto default_repr = c10::str(PReLU());
  ASSERT_EQ(default_repr, "torch::nn::PReLU(num_parameters=1)");

  const auto options = PReLUOptions().num_parameters(42);
  const auto custom_repr = c10::str(PReLU(options));
  ASSERT_EQ(custom_repr, "torch::nn::PReLU(num_parameters=42)");
}
5340*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintReLU) {
  // Default ReLU prints no arguments.
  const auto default_repr = c10::str(ReLU());
  ASSERT_EQ(default_repr, "torch::nn::ReLU()");

  // inplace set through the options object...
  const auto via_options = c10::str(ReLU(ReLUOptions().inplace(true)));
  ASSERT_EQ(via_options, "torch::nn::ReLU(inplace=true)");

  // ...and passed positionally to the constructor print identically.
  const auto via_ctor = c10::str(ReLU(/*inplace=*/true));
  ASSERT_EQ(via_ctor, "torch::nn::ReLU(inplace=true)");
}
5348*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintReLU6) {
  // Default ReLU6 prints no arguments.
  const auto default_repr = c10::str(ReLU6());
  ASSERT_EQ(default_repr, "torch::nn::ReLU6()");

  // inplace set through the options object...
  const auto via_options = c10::str(ReLU6(ReLU6Options().inplace(true)));
  ASSERT_EQ(via_options, "torch::nn::ReLU6(inplace=true)");

  // ...and passed positionally to the constructor print identically.
  const auto via_ctor = c10::str(ReLU6(/*inplace=*/true));
  ASSERT_EQ(via_ctor, "torch::nn::ReLU6(inplace=true)");
}
5357*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintRReLU) {
  // The default upper bound prints with six significant digits (0.333333).
  const auto default_repr = c10::str(RReLU());
  ASSERT_EQ(default_repr, "torch::nn::RReLU(lower=0.125, upper=0.333333)");

  const auto options = RReLUOptions().lower(0.24).upper(0.42).inplace(true);
  const auto custom_repr = c10::str(RReLU(options));
  ASSERT_EQ(
      custom_repr, "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)");
}
5364*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintCELU) {
  // Default: only alpha is printed; inplace is omitted when false.
  const auto default_repr = c10::str(CELU());
  ASSERT_EQ(default_repr, "torch::nn::CELU(alpha=1)");

  const auto options = CELUOptions().alpha(42.42).inplace(true);
  const auto custom_repr = c10::str(CELU(options));
  ASSERT_EQ(custom_repr, "torch::nn::CELU(alpha=42.42, inplace=true)");
}
5371*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSigmoid) {
  // Sigmoid has nothing to print beyond its name.
  const auto printed = c10::str(Sigmoid());
  ASSERT_EQ(printed, "torch::nn::Sigmoid()");
}
5375*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
  // The upscale factor is the only printed option.
  const auto printed = c10::str(PixelShuffle(PixelShuffleOptions(5)));
  ASSERT_EQ(printed, "torch::nn::PixelShuffle(upscale_factor=5)");
}
5381*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
  // The downscale factor is the only printed option.
  const auto printed = c10::str(PixelUnshuffle(PixelUnshuffleOptions(5)));
  ASSERT_EQ(printed, "torch::nn::PixelUnshuffle(downscale_factor=5)");
}
5387*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftplus) {
  // Default options print as beta=1, threshold=20.
  const auto default_repr = c10::str(Softplus());
  ASSERT_EQ(default_repr, "torch::nn::Softplus(beta=1, threshold=20)");

  const auto options = SoftplusOptions().beta(0.24).threshold(42.42);
  const auto custom_repr = c10::str(Softplus(options));
  ASSERT_EQ(custom_repr, "torch::nn::Softplus(beta=0.24, threshold=42.42)");
}
5394*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftshrink) {
  // The lambda value is printed bare, without a key.
  const auto default_repr = c10::str(Softshrink());
  ASSERT_EQ(default_repr, "torch::nn::Softshrink(0.5)");

  const auto custom_repr = c10::str(Softshrink(SoftshrinkOptions(42.42)));
  ASSERT_EQ(custom_repr, "torch::nn::Softshrink(42.42)");
}
5401*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintSoftsign) {
  // Softsign has nothing to print beyond its name.
  const auto printed = c10::str(Softsign());
  ASSERT_EQ(printed, "torch::nn::Softsign()");
}
5405*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintTanh) {
  // Tanh has nothing to print beyond its name.
  const auto printed = c10::str(Tanh());
  ASSERT_EQ(printed, "torch::nn::Tanh()");
}
5409*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintTanhshrink) {
  // Tanhshrink has nothing to print beyond its name.
  const auto printed = c10::str(Tanhshrink());
  ASSERT_EQ(printed, "torch::nn::Tanhshrink()");
}
5413*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintThreshold) {
  // Threshold and replacement value passed straight to the constructor.
  const auto positional = c10::str(Threshold(24.24, 42.42));
  ASSERT_EQ(positional, "torch::nn::Threshold(threshold=24.24, value=42.42)");

  // Same pair via the options object, with inplace appended when set.
  const auto options = ThresholdOptions(42.42, 24.24).inplace(true);
  const auto via_options = c10::str(Threshold(options));
  ASSERT_EQ(
      via_options,
      "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)");
}
5422*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintCTCLoss) {
  const auto default_repr = c10::str(CTCLoss());
  ASSERT_EQ(default_repr, "torch::nn::CTCLoss()");

  // Non-default options are expected NOT to change the printed form: the
  // expected string is identical to the default one.
  const auto options =
      CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum);
  const auto custom_repr = c10::str(CTCLoss(options));
  ASSERT_EQ(custom_repr, "torch::nn::CTCLoss()");
}
5431*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
  const auto default_repr = c10::str(PoissonNLLLoss());
  ASSERT_EQ(default_repr, "torch::nn::PoissonNLLLoss()");

  // Non-default options are expected NOT to change the printed form: the
  // expected string is identical to the default one.
  const auto options = PoissonNLLLossOptions()
                           .log_input(false)
                           .full(true)
                           .eps(0.42)
                           .reduction(torch::kSum);
  const auto custom_repr = c10::str(PoissonNLLLoss(options));
  ASSERT_EQ(custom_repr, "torch::nn::PoissonNLLLoss()");
}
5442*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) {
  const auto default_repr = c10::str(MarginRankingLoss());
  ASSERT_EQ(default_repr, "torch::nn::MarginRankingLoss()");

  // Non-default options are expected NOT to change the printed form: the
  // expected string is identical to the default one.
  const auto options =
      MarginRankingLossOptions().margin(0.5).reduction(torch::kSum);
  const auto custom_repr = c10::str(MarginRankingLoss(options));
  ASSERT_EQ(custom_repr, "torch::nn::MarginRankingLoss()");
}
5450*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) {
  // The size is printed bare as the first argument, followed by keyed
  // alpha/beta/k values; defaults are alpha=0.0001, beta=0.75, k=1.
  const auto default_repr = c10::str(CrossMapLRN2d(4));
  ASSERT_EQ(
      default_repr, "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)");

  // A small alpha is rendered in scientific notation.
  const auto options = CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10);
  const auto custom_repr = c10::str(CrossMapLRN2d(options));
  ASSERT_EQ(
      custom_repr, "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)");
}
5460*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintAlphaDropout) {
  // Default-constructed module.
  ASSERT_EQ(
      c10::str(AlphaDropout()),
      "torch::nn::AlphaDropout(p=0.5, inplace=false)");

  // Custom probability; the inplace flag is left untouched.
  auto options = AlphaDropoutOptions(0.2);
  ASSERT_EQ(
      c10::str(AlphaDropout(options)),
      "torch::nn::AlphaDropout(p=0.2, inplace=false)");

  // Same probability, now with in-place mode switched on.
  options.inplace(true);
  ASSERT_EQ(
      c10::str(AlphaDropout(options)),
      "torch::nn::AlphaDropout(p=0.2, inplace=true)");
}
5472*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) {
  // Default-constructed module.
  ASSERT_EQ(
      c10::str(FeatureAlphaDropout()),
      "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)");

  // Custom probability; the inplace flag is left untouched.
  auto options = FeatureAlphaDropoutOptions(0.2);
  ASSERT_EQ(
      c10::str(FeatureAlphaDropout(options)),
      "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)");

  // Same probability, now with in-place mode switched on.
  options.inplace(true);
  ASSERT_EQ(
      c10::str(FeatureAlphaDropout(options)),
      "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)");
}
5485*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) {
  // Per the expectations below, BCEWithLogitsLoss prints only its type name,
  // even when weight / pos_weight tensors and a reduction are configured.
  constexpr auto kExpected = "torch::nn::BCEWithLogitsLoss()";

  ASSERT_EQ(c10::str(BCEWithLogitsLoss()), kExpected);

  // Two independent tensors, exactly as the module would receive them from
  // separate factory calls.
  const auto custom_options = BCEWithLogitsLossOptions()
                                  .weight(torch::ones({3, 3}))
                                  .pos_weight(torch::ones({3, 3}))
                                  .reduction(torch::kSum);
  ASSERT_EQ(c10::str(BCEWithLogitsLoss(custom_options)), kExpected);
}
5495*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintMultiheadAttention) {
  // The printed form lists only the out_proj submodule. Its feature sizes
  // match the first constructor argument (20), and the bias flag set on the
  // options shows up on that Linear layer.
  ASSERT_EQ(
      c10::str(MultiheadAttention(20, 10)),
      "torch::nn::MultiheadAttention(\n"
      "  (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n"
      ")");

  const auto no_bias_options = MultiheadAttentionOptions(20, 10).bias(false);
  ASSERT_EQ(
      c10::str(MultiheadAttention(no_bias_options)),
      "torch::nn::MultiheadAttention(\n"
      "  (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n"
      ")");
}
5505*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintRNNCell) {
  // Default options: only the two sizes are printed.
  ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)");

  // With bias disabled, the expectations show that kTanh is omitted from the
  // output while kReLU is spelled out explicitly.
  const auto tanh_options =
      RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh);
  ASSERT_EQ(
      c10::str(RNNCell(tanh_options)),
      "torch::nn::RNNCell(20, 10, bias=false)");

  const auto relu_options =
      RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU);
  ASSERT_EQ(
      c10::str(RNNCell(relu_options)),
      "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)");
}
5517*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintLSTMCell) {
  // Default options: only the two sizes are printed.
  const LSTMCell default_cell(20, 10);
  ASSERT_EQ(c10::str(default_cell), "torch::nn::LSTMCell(20, 10)");

  // Disabling the bias appends the flag to the representation.
  const LSTMCell biasless_cell(LSTMCellOptions(20, 10).bias(false));
  ASSERT_EQ(c10::str(biasless_cell), "torch::nn::LSTMCell(20, 10, bias=false)");
}
5524*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest, PrettyPrintGRUCell) {
  // Default options: only the two sizes are printed.
  const GRUCell default_cell(20, 10);
  ASSERT_EQ(c10::str(default_cell), "torch::nn::GRUCell(20, 10)");

  // Disabling the bias appends the flag to the representation.
  const GRUCell biasless_cell(GRUCellOptions(20, 10).bias(false));
  ASSERT_EQ(c10::str(biasless_cell), "torch::nn::GRUCell(20, 10, bias=false)");
}
5531*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ModulesTest,PrettyPrintAdaptiveLogSoftmaxWithLoss)5532*da0073e9SAndroid Build Coastguard Worker TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
5533*da0073e9SAndroid Build Coastguard Worker   {
5534*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
5535*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
5536*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
5537*da0073e9SAndroid Build Coastguard Worker         c10::str(asfm),
5538*da0073e9SAndroid Build Coastguard Worker         "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5539*da0073e9SAndroid Build Coastguard Worker         "  (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
5540*da0073e9SAndroid Build Coastguard Worker         "  (tail): torch::nn::ModuleList(\n"
5541*da0073e9SAndroid Build Coastguard Worker         "    (0): torch::nn::Sequential(\n"
5542*da0073e9SAndroid Build Coastguard Worker         "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5543*da0073e9SAndroid Build Coastguard Worker         "      (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
5544*da0073e9SAndroid Build Coastguard Worker         "    )\n"
5545*da0073e9SAndroid Build Coastguard Worker         "  )\n"
5546*da0073e9SAndroid Build Coastguard Worker         ")");
5547*da0073e9SAndroid Build Coastguard Worker   }
5548*da0073e9SAndroid Build Coastguard Worker   {
5549*da0073e9SAndroid Build Coastguard Worker     AdaptiveLogSoftmaxWithLoss asfm(
5550*da0073e9SAndroid Build Coastguard Worker         AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
5551*da0073e9SAndroid Build Coastguard Worker             .div_value(2.)
5552*da0073e9SAndroid Build Coastguard Worker             .head_bias(true));
5553*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
5554*da0073e9SAndroid Build Coastguard Worker         c10::str(asfm),
5555*da0073e9SAndroid Build Coastguard Worker         "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5556*da0073e9SAndroid Build Coastguard Worker         "  (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
5557*da0073e9SAndroid Build Coastguard Worker         "  (tail): torch::nn::ModuleList(\n"
5558*da0073e9SAndroid Build Coastguard Worker         "    (0): torch::nn::Sequential(\n"
5559*da0073e9SAndroid Build Coastguard Worker         "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5560*da0073e9SAndroid Build Coastguard Worker         "      (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
5561*da0073e9SAndroid Build Coastguard Worker         "    )\n"
5562*da0073e9SAndroid Build Coastguard Worker         "    (1): torch::nn::Sequential(\n"
5563*da0073e9SAndroid Build Coastguard Worker         "      (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
5564*da0073e9SAndroid Build Coastguard Worker         "      (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
5565*da0073e9SAndroid Build Coastguard Worker         "    )\n"
5566*da0073e9SAndroid Build Coastguard Worker         "  )\n"
5567*da0073e9SAndroid Build Coastguard Worker         ")");
5568*da0073e9SAndroid Build Coastguard Worker   }
5569*da0073e9SAndroid Build Coastguard Worker }
5570