xref: /aosp_15_r20/external/executorch/kernels/test/op_native_batch_norm_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 
20 class OpNativeBatchNormLegitNoTrainingOutTest : public OperatorTest {
21  protected:
22   ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_no_training_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,const exec_aten::Tensor & running_mean,const exec_aten::Tensor & running_var,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)23   op_native_batch_norm_legit_no_training_out(
24       const exec_aten::Tensor& input,
25       const exec_aten::optional<exec_aten::Tensor>& weight,
26       const exec_aten::optional<exec_aten::Tensor>& bias,
27       const exec_aten::Tensor& running_mean,
28       const exec_aten::Tensor& running_var,
29       double momentum,
30       double eps,
31       exec_aten::Tensor& out0,
32       exec_aten::Tensor& out1,
33       exec_aten::Tensor& out2) {
34     return torch::executor::aten::_native_batch_norm_legit_no_training_outf(
35         context_,
36         input,
37         weight,
38         bias,
39         running_mean,
40         running_var,
41         momentum,
42         eps,
43         out0,
44         out1,
45         out2);
46   }
47 };
48 
49 class OpNativeBatchNormLegitOutTest : public OperatorTest {
50  protected:
51   ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,exec_aten::Tensor & running_mean,exec_aten::Tensor & running_var,bool training,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)52   op_native_batch_norm_legit_out(
53       const exec_aten::Tensor& input,
54       const exec_aten::optional<exec_aten::Tensor>& weight,
55       const exec_aten::optional<exec_aten::Tensor>& bias,
56       exec_aten::Tensor& running_mean,
57       exec_aten::Tensor& running_var,
58       bool training,
59       double momentum,
60       double eps,
61       exec_aten::Tensor& out0,
62       exec_aten::Tensor& out1,
63       exec_aten::Tensor& out2) {
64     executorch::runtime::KernelRuntimeContext context{};
65     return torch::executor::aten::_native_batch_norm_legit_outf(
66         context,
67         input,
68         weight,
69         bias,
70         running_mean,
71         running_var,
72         training,
73         momentum,
74         eps,
75         out0,
76         out1,
77         out2);
78   }
79 };
80 
81 class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
82  protected:
83   ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_no_stats_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,bool training,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)84   op_native_batch_norm_legit_no_stats_out(
85       const exec_aten::Tensor& input,
86       const exec_aten::optional<exec_aten::Tensor>& weight,
87       const exec_aten::optional<exec_aten::Tensor>& bias,
88       bool training,
89       double momentum,
90       double eps,
91       exec_aten::Tensor& out0,
92       exec_aten::Tensor& out1,
93       exec_aten::Tensor& out2) {
94     return torch::executor::aten::_native_batch_norm_legit_outf(
95         context_,
96         input,
97         weight,
98         bias,
99         training,
100         momentum,
101         eps,
102         out0,
103         out1,
104         out2);
105   }
106 };
107 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTest2D)108 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
109   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
110 
111   exec_aten::Tensor input = tfFloat.make(
112       {4, 7}, {2.876736640930176,  7.67944860458374,   5.701690196990967,
113                9.299789428710938,  3.023690700531006,  5.315116882324219,
114                7.185585021972656,  6.911304473876953,  7.61051082611084,
115                1.4963287115097046, 0.7381612062454224, 8.588483810424805,
116                6.583977699279785,  8.831110000610352,  0.8165055513381958,
117                7.087201118469238,  5.572513580322266,  4.446897983551025,
118                4.444573402404785,  6.254056930541992,  5.906398296356201,
119                9.971039772033691,  3.5423521995544434, 7.452159881591797,
120                9.93700122833252,   1.8560808897018433, 1.524025797843933,
121                7.3222975730896});
122   exec_aten::optional<exec_aten::Tensor> weight =
123       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
124           {7},
125           {8.287437438964844,
126            8.227645874023438,
127            6.65926456451416,
128            9.436124801635742,
129            4.119281768798828,
130            8.593960762023926,
131            2.3760855197906494}));
132   exec_aten::optional<exec_aten::Tensor> bias =
133       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
134           {7},
135           {7.824275970458984,
136            6.84327507019043,
137            8.354326248168945,
138            8.773970603942871,
139            3.89609694480896,
140            3.0753469467163086,
141            3.1105971336364746}));
142   exec_aten::Tensor running_mean = tfFloat.make(
143       {7},
144       {9.700226783752441,
145        0.1234668493270874,
146        7.527220249176025,
147        8.993252754211426,
148        0.4736626148223877,
149        7.7135701179504395,
150        5.12320613861084});
151   exec_aten::Tensor running_var = tfFloat.make(
152       {7},
153       {3.585531234741211,
154        6.615292549133301,
155        0.24084866046905518,
156        5.175800323486328,
157        0.5886000394821167,
158        6.23909854888916,
159        1.5029621124267578});
160   double momentum = 0.1;
161   double eps = 0;
162   exec_aten::Tensor out0 = tfFloat.zeros({4, 7});
163   exec_aten::Tensor out1 = tfFloat.zeros({0});
164   exec_aten::Tensor out2 = tfFloat.zeros({0});
165   exec_aten::Tensor out0_expected = tfFloat.make(
166       {4, 7}, {-22.039867401123047, 31.014127731323242,  -16.416650772094727,
167                10.04538631439209,   17.5877628326416,    -5.17673921585083,
168                7.1078033447265625,  -4.381907939910889,  30.793603897094727,
169                -73.48003387451172,  -25.46548080444336,  47.46636962890625,
170                -0.8111140131950378, 10.29708194732666,   -31.056814193725586,
171                29.119586944580078,  -18.16947364807129,  -10.082839965820312,
172                25.216796875,        -1.9462348222732544, 4.628543376922607,
173                9.00953483581543,    17.779958724975586,  7.335818767547607,
174                12.688335418701172,  11.318607330322266,  -18.22031593322754,
175                7.372773170471191});
176   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
177   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
178   op_native_batch_norm_legit_no_training_out(
179       input,
180       weight,
181       bias,
182       running_mean,
183       running_var,
184       momentum,
185       eps,
186       out0,
187       out1,
188       out2);
189   EXPECT_TENSOR_CLOSE(out0, out0_expected);
190   EXPECT_TENSOR_CLOSE(out1, out1_expected);
191   EXPECT_TENSOR_CLOSE(out2, out2_expected);
192 }
193 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTest3D)194 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest3D) {
195   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
196 
197   exec_aten::Tensor input = tfFloat.make(
198       {4, 7, 5}, {5.277339935302734,  5.94276237487793,     6.543086051940918,
199                   2.411264181137085,  8.980886459350586,    2.7123653888702393,
200                   9.466896057128906,  9.324702262878418,    1.9848430156707764,
201                   8.388091087341309,  1.5069717168807983,   5.350819110870361,
202                   1.727534532546997,  4.913003444671631,    2.555372714996338,
203                   4.321412563323975,  1.107364296913147,    6.048641681671143,
204                   9.496582984924316,  0.9668296575546265,   0.8103430271148682,
205                   8.187652587890625,  9.455179214477539,    0.5739009380340576,
206                   3.550161838531494,  1.5362483263015747,   7.338945388793945,
207                   3.583885431289673,  6.5086517333984375,   0.9027481079101562,
208                   0.8805221319198608, 3.983092784881592,    5.43976354598999,
209                   9.080245971679688,  2.602390766143799,    2.1537625789642334,
210                   3.2551045417785645, 7.098634719848633,    8.135055541992188,
211                   7.457048416137695,  5.3438568115234375,   3.7822632789611816,
212                   3.4284191131591797, 6.144853115081787,    9.79615592956543,
213                   5.735219955444336,  2.5468051433563232,   8.514262199401855,
214                   3.775507926940918,  8.327726364135742,    4.772212505340576,
215                   7.100861072540283,  3.477612018585205,    9.359293937683105,
216                   5.203947067260742,  3.6150975227355957,   6.159048557281494,
217                   0.9919929504394531, 1.6809028387069702,   0.3627735376358032,
218                   1.8791186809539795, 4.037001132965088,    8.129783630371094,
219                   4.79802131652832,   2.9911656379699707,   8.659820556640625,
220                   7.378345489501953,  3.6833512783050537,   2.4555420875549316,
221                   8.481515884399414,  3.733121156692505,    6.075705528259277,
222                   6.900073051452637,  6.380939960479736,    3.204977512359619,
223                   2.058135986328125,  4.60728120803833,     7.737727165222168,
224                   5.3178815841674805, 9.224492073059082,    4.838874340057373,
225                   2.717348337173462,  1.8555694818496704,   1.856197714805603,
226                   7.189084053039551,  5.280246257781982,    7.550882816314697,
227                   0.6145977973937988, 6.764681816101074,    4.217874526977539,
228                   0.89302659034729,   2.4634499549865723,   3.51415753364563,
229                   5.038887977600098,  4.948186874389648,    8.326996803283691,
230                   8.919670104980469,  4.45585298538208,     0.5209791660308838,
231                   4.2513017654418945, 0.047875046730041504, 2.453791618347168,
232                   6.113187789916992,  5.47722053527832,     7.524778842926025,
233                   0.3724473714828491, 2.6570069789886475,   9.420238494873047,
234                   4.650344371795654,  4.206380844116211,    1.2107867002487183,
235                   3.3689606189727783, 4.082674980163574,    5.31553840637207,
236                   4.759864807128906,  5.461820602416992,    2.0690488815307617,
237                   9.234517097473145,  1.6740238666534424,   3.492245674133301,
238                   9.844581604003906,  4.278226852416992,    2.9611783027648926,
239                   9.626322746276855,  7.756594657897949,    3.4873299598693848,
240                   6.345180988311768,  5.55388069152832,     8.535417556762695,
241                   8.509242057800293,  8.684778213500977,    3.784114122390747,
242                   3.887125253677368,  9.278786659240723,    6.742891311645508,
243                   5.01821756362915,   2.326876640319824,    7.939553737640381,
244                   3.2622408866882324, 3.829448699951172});
245   exec_aten::optional<exec_aten::Tensor> weight =
246       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
247           {7},
248           {0.5193436145782471,
249            4.531304836273193,
250            8.960723876953125,
251            8.598731994628906,
252            2.6848177909851074,
253            7.309220314025879,
254            2.2476916313171387}));
255   exec_aten::optional<exec_aten::Tensor> bias =
256       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
257           {7},
258           {4.643010139465332,
259            0.2791440486907959,
260            3.6721653938293457,
261            3.918765068054199,
262            2.6499342918395996,
263            5.721188545227051,
264            5.901060104370117}));
265   exec_aten::Tensor running_mean = tfFloat.make(
266       {7},
267       {5.818909645080566,
268        5.325511932373047,
269        7.094021797180176,
270        4.9185566902160645,
271        5.608961582183838,
272        3.7719011306762695,
273        6.7734270095825195});
274   exec_aten::Tensor running_var = tfFloat.make(
275       {7},
276       {8.8593168258667,
277        3.440363883972168,
278        7.105681896209717,
279        1.0423260927200317,
280        6.756608009338379,
281        4.527579307556152,
282        2.022289752960205});
283   double momentum = 0.1;
284   double eps = 0;
285   exec_aten::Tensor out0 = tfFloat.zeros({4, 7, 5});
286   exec_aten::Tensor out1 = tfFloat.zeros({0});
287   exec_aten::Tensor out2 = tfFloat.zeros({0});
288   exec_aten::Tensor out0_expected = tfFloat.make(
289       {4, 7, 5}, {4.5485148429870605,  4.664620399475098,   4.76936674118042,
290                   4.048431873321533,   5.194723129272461,   -6.104737281799316,
291                   10.396490097045898,  10.049112319946289,  -7.8820648193359375,
292                   7.760983943939209,   -15.109009742736816, -2.1877059936523438,
293                   -14.367575645446777, -3.659447431564331,  -11.584752082824707,
294                   -1.1105821132659912, -28.180377960205078, 13.436722755432129,
295                   42.476444244384766,  -29.3640079498291,   -2.306469440460205,
296                   5.313416481018066,   6.622621059417725,   -2.5506861209869385,
297                   0.5234383940696716,  -1.9584782123565674, 17.97430419921875,
298                   5.075337886810303,   15.122170448303223,  -4.134607791900635,
299                   -3.413116931915283,  1.4907281398773193,  3.793105363845825,
300                   9.547160148620605,   -0.69157475233078,   4.003501892089844,
301                   4.1956682205200195,  4.8663010597229,     5.047139644622803,
302                   4.92883825302124,    0.3239605128765106,  -3.4909913539886475,
303                   -4.3554277420043945, 2.2807836532592773,  11.20086669921875,
304                   -0.8955214619636536, -11.613553047180176, 8.446381568908691,
305                   -7.483201026916504,  7.819331169128418,   2.686206579208374,
306                   22.29886817932129,   -8.217354774475098,  41.320152282714844,
307                   6.322420597076416,   0.5905092358589172,  3.218108892440796,
308                   -2.1188466548919678, -1.4072843790054321, -2.7687556743621826,
309                   -0.7806879878044128, 6.63183069229126,    20.690902709960938,
310                   9.246002197265625,   3.039292335510254,   8.882646560668945,
311                   6.857179164886475,   1.016964316368103,   -0.9236800670623779,
312                   8.600822448730469,   4.279074192047119,   4.687816619873047,
313                   4.831655502319336,   4.741075038909912,   4.1869215965271,
314                   -7.7030110359191895, -1.475483775138855,  6.172153472900391,
315                   0.2605033814907074,  9.804300308227539,   -3.9086363315582275,
316                   -11.040262222290039, -13.937179565429688, -13.935067176818848,
317                   3.991722345352173,   6.965037822723389,   26.08910369873047,
318                   -32.330623626708984, 19.467453002929688,  -1.9826143980026245,
319                   -2.221067190170288,  -0.5990060567855835, 0.48625022172927856,
320                   2.0611159801483154,  1.9674323797225952,  21.36834716796875,
321                   23.404233932495117,  8.070624351501465,   -5.446018218994141,
322                   7.367972373962402,   -4.729177951812744,  -0.9264468550682068,
323                   4.8575029373168945,  3.852308988571167,   7.08862829208374,
324                   3.6926915645599365,  4.091310024261475,   5.271382808685303,
325                   4.439114570617676,   4.361649990081787,   -9.77307415008545,
326                   -4.5006842613220215, -2.757089614868164,  0.25477901101112366,
327                   -1.1027240753173828, -1.8145684003829956, -13.21955680847168,
328                   10.867557525634766,  -14.547454833984375, -8.435402870178223,
329                   45.407405853271484,  -1.4743067026138306, -12.566932678222656,
330                   43.569156646728516,  27.821678161621094,  0.45854052901268005,
331                   3.4103615283966064,  2.5930423736572266,  5.672616004943848,
332                   5.645579814910889,   22.59735870361328,   5.76314115524292,
333                   6.116993427276611,   24.63783073425293,   15.926804542541504,
334                   3.1268203258514404,  -1.1270453929901123, 7.744210720062256,
335                   0.3513677716255188,  1.2478822469711304});
336   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
337   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
338   op_native_batch_norm_legit_no_training_out(
339       input,
340       weight,
341       bias,
342       running_mean,
343       running_var,
344       momentum,
345       eps,
346       out0,
347       out1,
348       out2);
349   EXPECT_TENSOR_CLOSE(out0, out0_expected);
350   EXPECT_TENSOR_CLOSE(out1, out1_expected);
351   EXPECT_TENSOR_CLOSE(out2, out2_expected);
352 }
353 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTest4D)354 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest4D) {
355   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
356 
357   exec_aten::Tensor input = tfFloat.make(
358       {2, 4, 5, 5},
359       {8.0573148727417,     2.2901253700256348,  2.783101797103882,
360        2.095468044281006,   6.389344215393066,   6.702191352844238,
361        6.535638809204102,   1.8584740161895752,  6.037202835083008,
362        7.588045120239258,   0.7384824752807617,  7.876931190490723,
363        2.198972225189209,   0.8259981870651245,  8.311962127685547,
364        8.748727798461914,   6.331905841827393,   1.6120970249176025,
365        7.9793596267700195,  9.730956077575684,   8.96406078338623,
366        7.34755802154541,    1.0760420560836792,  6.761768341064453,
367        3.18643856048584,    0.32129645347595215, 5.146165370941162,
368        1.9008630514144897,  3.1616015434265137,  3.077312707901001,
369        7.684902667999268,   3.1405091285705566,  6.699800491333008,
370        0.7976526021957397,  1.5945738554000854,  0.7354140281677246,
371        9.370306015014648,   4.1550726890563965,  4.169681549072266,
372        5.389268398284912,   6.883472442626953,   8.881608963012695,
373        7.600193500518799,   8.894989967346191,   1.7032986879348755,
374        8.945396423339844,   1.6370415687561035,  7.708703994750977,
375        7.488667964935303,   7.315606594085693,   5.349757194519043,
376        6.913224220275879,   3.6051642894744873,  3.8086843490600586,
377        3.2311654090881348,  4.91132926940918,    1.331128478050232,
378        2.73335337638855,    0.46345293521881104, 8.168035507202148,
379        8.112630844116211,   9.38737678527832,    8.532957077026367,
380        8.641634941101074,   7.772867679595947,   3.7504279613494873,
381        1.1857783794403076,  7.61868953704834,    9.75157642364502,
382        3.6754441261291504,  2.468808174133301,   6.380059719085693,
383        6.197269439697266,   7.659857273101807,   6.72884464263916,
384        9.320260047912598,   1.9144713878631592,  6.228992462158203,
385        2.7658307552337646,  6.0448317527771,     1.1033517122268677,
386        7.482324600219727,   4.140635013580322,   0.4461771249771118,
387        9.729606628417969,   7.259793758392334,   7.154001235961914,
388        8.320201873779297,   0.8773839473724365,  6.855964660644531,
389        4.737044334411621,   4.0600152015686035,  6.474225044250488,
390        0.8523398637771606,  3.7826621532440186,  5.399431228637695,
391        0.17764925956726074, 5.480880260467529,   1.5790224075317383,
392        7.965246200561523,   0.919603705406189,   6.623161315917969,
393        6.618031978607178,   1.6051316261291504,  0.07815778255462646,
394        7.8453497886657715,  2.781987190246582,   0.28109610080718994,
395        9.149931907653809,   7.448637962341309,   5.52522087097168,
396        4.095173358917236,   6.3080902099609375,  5.314402103424072,
397        8.845094680786133,   6.3725972175598145,  1.9547373056411743,
398        5.2839508056640625,  3.5294246673583984,  3.570653200149536,
399        2.5026822090148926,  0.5656778812408447,  8.309356689453125,
400        0.7813519239425659,  2.366170883178711,   9.322799682617188,
401        0.5455368757247925,  0.7133877277374268,  6.577077388763428,
402        8.393207550048828,   5.753355979919434,   7.874646186828613,
403        6.351865768432617,   7.233908176422119,   7.866637706756592,
404        5.024176120758057,   5.872377872467041,   0.3430730104446411,
405        1.7413997650146484,  7.130331993103027,   7.7794294357299805,
406        8.817843437194824,   4.551261901855469,   4.685880661010742,
407        0.4518568515777588,  3.2571589946746826,  9.467324256896973,
408        6.947274208068848,   1.1890357732772827,  4.438136100769043,
409        0.790744423866272,   0.9745275974273682,  2.3840129375457764,
410        9.280584335327148,   7.309266090393066,   6.359057903289795,
411        4.779758930206299,   6.523046970367432,   2.581796169281006,
412        7.4173126220703125,  5.556275844573975,   6.3515143394470215,
413        9.909261703491211,   4.264077663421631,   1.5390598773956299,
414        6.409996032714844,   9.431000709533691,   6.966275215148926,
415        6.593939781188965,   9.72049331665039,    8.224472045898438,
416        1.1502748727798462,  9.417522430419922,   2.0071351528167725,
417        7.99619722366333,    5.217411518096924,   0.5482637882232666,
418        3.6407017707824707,  9.56554889678955,    5.932462215423584,
419        8.26833724975586,    2.5603179931640625,  7.974213600158691,
420        6.683809280395508,   5.0010175704956055,  8.93687915802002,
421        4.7291178703308105,  1.1585253477096558,  2.50417423248291,
422        3.685148239135742,   0.36632418632507324, 7.834067344665527,
423        9.173870086669922,   3.781676769256592,   5.6734232902526855,
424        3.301741600036621,   1.3799077272415161,  8.990988731384277,
425        2.2520315647125244,  2.483280897140503});
426   exec_aten::optional<exec_aten::Tensor> weight =
427       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
428           {4},
429           {1.8311285972595215,
430            5.851841926574707,
431            6.108979225158691,
432            5.1755266189575195}));
433   exec_aten::optional<exec_aten::Tensor> bias =
434       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
435           {4},
436           {5.1375732421875,
437            3.7950849533081055,
438            2.406358242034912,
439            5.785604476928711}));
440   exec_aten::Tensor running_mean = tfFloat.make(
441       {4},
442       {2.8203158378601074,
443        3.1786017417907715,
444        1.9189423322677612,
445        1.8829244375228882});
446   exec_aten::Tensor running_var = tfFloat.make(
447       {4},
448       {1.4411485195159912,
449        7.426868438720703,
450        7.584629535675049,
451        5.526189804077148});
452   double momentum = 0.1;
453   double eps = 0;
454   exec_aten::Tensor out0 = tfFloat.zeros({2, 4, 5, 5});
455   exec_aten::Tensor out1 = tfFloat.zeros({0});
456   exec_aten::Tensor out2 = tfFloat.zeros({0});
457   exec_aten::Tensor out0_expected = tfFloat.make(
458       {2, 4, 5, 5},
459       {13.125737190246582,   4.328856468200684,    5.080809593200684,
460        4.031939506530762,    10.581527709960938,   11.058723449707031,
461        10.804675102233887,   3.6704447269439697,   10.044395446777344,
462        12.409944534301758,   1.962085485458374,    12.850592613220215,
463        4.189817905426025,    2.095576047897339,    13.514159202575684,
464        14.180371284484863,   10.493914604187012,   3.29463791847229,
465        13.006829261779785,   15.678596496582031,   14.50882625579834,
466        12.043122291564941,   2.476976156234741,    11.149598121643066,
467        5.6960320472717285,   -2.340364456176758,   8.020005226135254,
468        1.0514155626296997,   3.7585806846618652,   3.5775885581970215,
469        13.47139835357666,    3.713289260864258,    11.35610294342041,
470        -1.3174920082092285,  0.3937252461910248,   -1.4511359930038452,
471        17.09044075012207,    5.891846656799316,    5.923215866088867,
472        8.542016983032227,    11.75049877166748,    16.041067123413086,
473        13.28950309753418,    16.069801330566406,   0.6271884441375732,
474        16.178037643432617,   0.48491552472114563,  13.522506713867188,
475        13.050026893615723,   12.678414344787598,   10.01660442352295,
476        13.48469352722168,    6.146742343902588,    6.598191261291504,
477        5.317136287689209,    9.044082641601562,    1.1024672985076904,
478        4.212887763977051,    -0.8222138285636902,  16.26811981201172,
479        16.145221710205078,   18.972867965698242,   17.077590942382812,
480        17.318660736083984,   15.391557693481445,   6.468966484069824,
481        0.7800511717796326,   15.049558639526367,   19.780736923217773,
482        6.302637100219727,    3.626072645187378,    12.30202579498291,
483        11.896559715270996,   15.140877723693848,   13.075701713562012,
484        22.15976333618164,    5.855058670043945,    15.353979110717773,
485        7.729425430297852,    14.948527336120605,   4.069284439086914,
486        18.11333465576172,    10.756217002868652,   2.6224381923675537,
487        23.06098747253418,    17.6234073638916,     17.390493392944336,
488        19.958019256591797,   3.5717902183532715,   16.734331130981445,
489        12.069281578063965,   10.578722953796387,   15.89388656616211,
490        3.516652822494507,    9.968097686767578,    13.527603149414062,
491        2.031242847442627,    13.70692253112793,    5.1165289878845215,
492        19.176542282104492,   2.2383556365966797,   10.938176155090332,
493        10.930352210998535,   3.284013509750366,    0.9548709392547607,
494        12.802419662475586,   5.079109191894531,    1.2644193172454834,
495        14.792341232299805,   12.19730281829834,    9.263452529907227,
496        7.082154750823975,    10.457588195800781,   8.94188404083252,
497        14.327363014221191,   10.55598258972168,    3.8172783851623535,
498        8.895435333251953,    6.2191996574401855,   6.282087326049805,
499        4.653076171875,       1.6985011100769043,   13.510184288024902,
500        2.027475595474243,    4.444851398468018,    16.98843002319336,
501        -1.8588563203811646,  -1.4984327554702759,  11.092581748962402,
502        14.992330551147461,   9.323816299438477,    13.87883186340332,
503        10.608987808227539,   12.502984046936035,   13.861635208129883,
504        7.758059501647949,    9.579390525817871,    -2.2936041355133057,
505        0.7090023756027222,   12.280576705932617,   13.6743745803833,
506        15.904145240783691,   6.742578029632568,    7.031642436981201,
507        -2.060014009475708,   3.9637696743011475,   17.298765182495117,
508        11.887499809265137,   -0.47708070278167725, 6.499664306640625,
509        -0.09621463716030121, 0.3114539086818695,   3.4379796981811523,
510        18.735980987548828,   14.363194465637207,   12.255439758300781,
511        8.752232551574707,    12.619200706481934,   3.8767030239105225,
512        14.602864265441895,   10.47470474243164,    12.238706588745117,
513        20.13051414489746,    7.608346462249756,    1.5637015104293823,
514        12.368430137634277,   19.06963539123535,    13.602371215820312,
515        12.77645492553711,    19.711788177490234,   16.393308639526367,
516        0.7012971639633179,   19.039737701416016,   2.601987838745117,
517        15.886947631835938,   13.12686538696289,    2.847193956375122,
518        9.655555725097656,    22.69979476928711,    14.701132774353027,
519        19.843833923339844,   7.276965141296387,    19.196285247802734,
520        16.355310440063477,   12.6504487991333,     21.315706253051758,
521        12.051830291748047,   4.190755844116211,    7.153357982635498,
522        9.753409385681152,    2.4466326236724854,   18.887737274169922,
523        21.83746910095215,    9.96592903137207,     14.130828857421875,
524        8.909295082092285,    4.678154945373535,    21.43483543395996,
525        6.598236560821533,    7.107358932495117});
526   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
527   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
528   op_native_batch_norm_legit_no_training_out(
529       input,
530       weight,
531       bias,
532       running_mean,
533       running_var,
534       momentum,
535       eps,
536       out0,
537       out1,
538       out2);
539   EXPECT_TENSOR_CLOSE(out0, out0_expected);
540   EXPECT_TENSOR_CLOSE(out1, out1_expected);
541   EXPECT_TENSOR_CLOSE(out2, out2_expected);
542 }
543 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTestDouble)544 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestDouble) {
545   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Double>
546       tfDouble;
547 
548   exec_aten::Tensor input = tfDouble.make(
549       {3, 4, 3, 3},
550       {0.09871780872344971, 5.7593607902526855,  4.542290687561035,
551        9.888419151306152,   4.6276702880859375,  0.23040294647216797,
552        5.160412311553955,   5.192661285400391,   7.774633407592773,
553        3.82037353515625,    6.421841621398926,   1.3372838497161865,
554        5.101180553436279,   3.166962146759033,   0.253373384475708,
555        5.272202491760254,   0.8737403154373169,  9.0341796875,
556        4.930244445800781,   5.145639896392822,   8.51688003540039,
557        1.0039496421813965,  6.3629584312438965,  8.20095157623291,
558        7.129164695739746,   6.775269031524658,   5.83862829208374,
559        4.415182590484619,   9.107303619384766,   1.1548930406570435,
560        4.394702434539795,   7.173308372497559,   1.648862361907959,
561        3.040163516998291,   8.946229934692383,   8.740336418151855,
562        7.152044773101807,   6.766063690185547,   8.682901382446289,
563        2.8464317321777344,  8.757857322692871,   8.097877502441406,
564        5.039367198944092,   1.713152527809143,   1.5446704626083374,
565        7.220646858215332,   5.2453131675720215,  7.095609188079834,
566        6.792170524597168,   5.975555896759033,   3.7161855697631836,
567        2.0132927894592285,  3.0089516639709473,  1.4530837535858154,
568        2.124783515930176,   1.3747084140777588,  0.4398918151855469,
569        5.140370845794678,   0.16295194625854492, 5.689471244812012,
570        9.149665832519531,   9.32123851776123,    1.5971916913986206,
571        3.5363614559173584,  0.4872584342956543,  7.255306243896484,
572        8.349767684936523,   0.977746844291687,   0.010267496109008789,
573        9.964345932006836,   9.955519676208496,   2.3190832138061523,
574        9.237786293029785,   4.200929641723633,   9.231035232543945,
575        3.777331829071045,   4.507022857666016,   9.332846641540527,
576        0.8198702335357666,  0.8076483011245728,  6.062283992767334,
577        5.735506057739258,   6.782886505126953,   6.669310569763184,
578        5.708680152893066,   7.5679931640625,     4.829475402832031,
579        1.1562585830688477,  0.5352389812469482,  4.793148040771484,
580        1.7251378297805786,  9.661691665649414,   7.695187568664551,
581        2.569558620452881,   5.02672004699707,    4.213432312011719,
582        0.4719752073287964,  3.2524518966674805,  4.827580451965332,
583        1.7936384677886963,  1.8733304738998413,  9.386192321777344,
584        2.442445755004883,   2.2374587059020996,  1.6268903017044067,
585        1.9272565841674805,  0.04978537559509277, 5.165012359619141});
586   exec_aten::optional<exec_aten::Tensor> weight =
587       exec_aten::optional<exec_aten::Tensor>(tfDouble.make(
588           {4},
589           {5.4100823402404785,
590            3.3440847396850586,
591            0.9714162349700928,
592            0.6811875104904175}));
593   exec_aten::optional<exec_aten::Tensor> bias =
594       exec_aten::optional<exec_aten::Tensor>(tfDouble.make(
595           {4},
596           {6.839208126068115,
597            6.471728801727295,
598            3.077871799468994,
599            4.0067667961120605}));
600   exec_aten::Tensor running_mean = tfDouble.make(
601       {4},
602       {8.781468391418457,
603        5.093882083892822,
604        9.076446533203125,
605        7.148240089416504});
606   exec_aten::Tensor running_var = tfDouble.make(
607       {4},
608       {1.0133814811706543,
609        2.674386978149414,
610        6.866252422332764,
611        9.597100257873535});
612   double momentum = 0.1;
613   double eps = 0;
614   exec_aten::Tensor out0 = tfDouble.zeros({3, 4, 3, 3});
615   exec_aten::Tensor out1 = tfDouble.zeros({0});
616   exec_aten::Tensor out2 = tfDouble.zeros({0});
617   exec_aten::Tensor out0_expected = tfDouble.make(
618       {3, 4, 3, 3},
619       {-39.82401348817106,   -9.402336001242755,   -15.94316789328793,
620        12.788231783114975,   -15.48431707375971,   -39.11630540562901,
621        -12.621231365199568,  -12.447917505830254,  1.4282310938746887,
622        3.8675726975224554,   9.187229957306027,    -1.2100164305929255,
623        6.486653204105122,    2.53143305610451,     -3.4264695963512772,
624        6.836370389027959,    -2.1579014884157313,  14.529114884449125,
625        1.540793678645762,    1.6206449805000125,   2.870429566122108,
626        0.08523948439124736,  2.07192874490631,     2.7533087138107413,
627        2.3559763770206743,   2.2247803400083974,   1.8775493181436829,
628        3.4058069855539905,   4.437536528655581,    2.688916473373198,
629        3.401303695505533,    4.012278948950884,    2.797533181889745,
630        3.1034601808347273,   4.402118755309629,    4.356845749255535,
631        -1.9177122392967219,  -3.992068820138994,   6.30948495370292,
632        -25.057127980919624,  6.712316477851606,    3.165423782960796,
633        -13.271757354773879,  -31.14764712696726,   -32.05311088198505,
634        10.820680738292133,   6.781385286813924,    10.564995283913401,
635        9.94450345561426,     8.274634831663992,    3.654522124495513,
636        0.1723322067015646,   2.2083225723239543,   -0.9732209832222724,
637        0.5007544998438981,   0.2226870048863787,   -0.12386777243752074,
638        1.6186916404984417,   -0.22653479260151907, 1.822253886542661,
639        3.10501562425824,     3.1686209708035005,   0.3051659025884432,
640        3.212566930139726,    2.542113280715248,    4.0303090947788895,
641        4.270965334366817,    2.6499645871979483,   2.437229873043587,
642        4.625987736153007,    4.624046970174033,    2.9449050525314044,
643        9.29157194403522,     -17.777725500427557,  9.25529009664348,
644        -20.05424357138422,   -16.132705822147077,  9.802449466893506,
645        -35.948364280449816,  -36.01404792933805,   -7.774352749114295,
646        7.783764743499358,    9.92551886693686,     9.69327114024373,
647        7.728909325429857,    11.53095787270343,    5.931052201511217,
648        -1.5801890954114401,  -2.850091828706626,   5.856767563411625,
649        0.3525980358528258,   3.2948336043782858,   2.565812114767675,
650        0.6656413209484225,   1.5765590689152436,   1.2750574158293795,
651        -0.11197383213610211, 0.9188032005673643,   1.5027341303183273,
652        2.829367353398184,    2.8468904728035618,   4.498860120209044,
653        2.972030690911762,    2.9269570039352497,   2.792701843675069,
654        2.858748044414565,    2.4459192831196264,   3.570683705559329});
655   exec_aten::Tensor out1_expected = tfDouble.make({0}, {});
656   exec_aten::Tensor out2_expected = tfDouble.make({0}, {});
657   op_native_batch_norm_legit_no_training_out(
658       input,
659       weight,
660       bias,
661       running_mean,
662       running_var,
663       momentum,
664       eps,
665       out0,
666       out1,
667       out2);
668   EXPECT_TENSOR_CLOSE(out0, out0_expected);
669   EXPECT_TENSOR_CLOSE(out1, out1_expected);
670   EXPECT_TENSOR_CLOSE(out2, out2_expected);
671 }
672 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTestNoWeight)673 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestNoWeight) {
674   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
675 
676   exec_aten::Tensor input = tfFloat.make(
677       {4, 7, 5}, {4.1944355964660645,  3.537543296813965,  5.067144393920898,
678                   9.735533714294434,   2.661299228668213,  0.43786585330963135,
679                   8.926244735717773,   8.796754837036133,  2.2966713905334473,
680                   7.153128623962402,   7.055768013000488,  0.3383845090866089,
681                   0.8306580781936646,  2.355782985687256,  5.922069072723389,
682                   1.9597464799880981,  2.731785774230957,  3.488309383392334,
683                   9.926213264465332,   4.582781791687012,  1.2061834335327148,
684                   9.317821502685547,   2.9511327743530273, 1.7717409133911133,
685                   6.329389572143555,   0.844573974609375,  4.269064903259277,
686                   3.9711995124816895,  0.7241052389144897, 2.239838123321533,
687                   2.2850823402404785,  8.232909202575684,  5.126026153564453,
688                   0.09984314441680908, 4.0997748374938965, 8.717041969299316,
689                   2.4102187156677246,  8.769938468933105,  9.614383697509766,
690                   4.630570411682129,   7.450488090515137,  2.7233500480651855,
691                   5.878231525421143,   1.5304350852966309, 4.100255489349365,
692                   3.448119640350342,   1.356201171875,     7.190479278564453,
693                   4.431788444519043,   9.268322944641113,  7.564930438995361,
694                   5.517428398132324,   6.40336799621582,   1.5203499794006348,
695                   8.397398948669434,   9.415580749511719,  9.271242141723633,
696                   6.522747993469238,   9.739391326904297,  3.8692879676818848,
697                   4.59047794342041,    0.6365865468978882, 4.950358867645264,
698                   2.111414670944214,   3.189572811126709,  2.893986701965332,
699                   9.007704734802246,   1.0862338542938232, 4.761219024658203,
700                   0.5109339952468872,  4.226720333099365,  9.338176727294922,
701                   9.641677856445312,   8.222650527954102,  3.068296432495117,
702                   3.6851234436035156,  2.7459187507629395, 9.115739822387695,
703                   3.6909985542297363,  6.9336957931518555, 7.548684597015381,
704                   9.266566276550293,   4.114157676696777,  1.0546678304672241,
705                   1.881745457649231,   4.227387428283691,  1.3194853067398071,
706                   6.739812850952148,   6.846013069152832,  7.290800094604492,
707                   2.164156436920166,   3.4476895332336426, 7.013863563537598,
708                   6.375678062438965,   2.4389731884002686, 5.257430553436279,
709                   0.5499267578125,     5.771737098693848,  5.308223247528076,
710                   0.2141815423965454,  5.413756370544434,  1.757289171218872,
711                   9.780686378479004,   4.005618095397949,  7.078739166259766,
712                   4.428859710693359,   2.348038673400879,  4.718813419342041,
713                   1.896933913230896,   4.842776775360107,  6.077881813049316,
714                   5.315243721008301,   5.951466083526611,  7.1189398765563965,
715                   4.036149024963379,   9.996458053588867,  0.9982073307037354,
716                   1.865202784538269,   0.5543112754821777, 4.5034308433532715,
717                   4.392091751098633,   9.904728889465332,  2.8027725219726562,
718                   8.39471435546875,    7.3801398277282715, 3.346047878265381,
719                   1.2300896644592285,  6.925620079040527,  4.869058132171631,
720                   0.06555616855621338, 2.475562572479248,  0.5495405197143555,
721                   6.707937240600586,   0.946076512336731,  6.623589515686035,
722                   5.87992000579834,    2.196932315826416,  8.085456848144531,
723                   7.774395942687988,   8.86058235168457});
724   exec_aten::optional<exec_aten::Tensor> weight;
725   exec_aten::optional<exec_aten::Tensor> bias =
726       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
727           {7},
728           {3.2798612117767334,
729            7.070205211639404,
730            0.8457618951797485,
731            8.21817684173584,
732            4.158933162689209,
733            9.13807201385498,
734            5.7105536460876465}));
735   exec_aten::Tensor running_mean = tfFloat.make(
736       {7},
737       {8.596701622009277,
738        8.133163452148438,
739        1.8364977836608887,
740        9.756494522094727,
741        6.470483779907227,
742        6.9614739418029785,
743        5.237721920013428});
744   exec_aten::Tensor running_var = tfFloat.make(
745       {7},
746       {2.258641242980957,
747        0.8535522222518921,
748        9.372869491577148,
749        8.911684036254883,
750        9.814156532287598,
751        0.5796539783477783,
752        5.289167881011963});
753   double momentum = 0.1;
754   double eps = 0;
755   exec_aten::Tensor out0 = tfFloat.zeros({4, 7, 5});
756   exec_aten::Tensor out1 = tfFloat.zeros({0});
757   exec_aten::Tensor out2 = tfFloat.zeros({0});
758   exec_aten::Tensor out0_expected = tfFloat.make(
759       {4, 7, 5}, {0.3506367802619934,  -0.08645286411046982, 0.9313285946846008,
760                   4.037628650665283,   -0.669497013092041,   -1.259130597114563,
761                   7.928630828857422,   7.788471698760986,    0.7528274059295654,
762                   6.009422302246094,   2.5505621433258057,   0.3564245402812958,
763                   0.5172187089920044,  1.0153789520263672,   2.180255651473999,
764                   5.606414794921875,   5.865033149719238,    6.1184539794921875,
765                   8.275029182434082,   6.485081672668457,    2.478527545928955,
766                   5.067825794219971,   3.0355288982391357,   2.659057855606079,
767                   4.113894939422607,   1.1037919521331787,   5.601710796356201,
768                   5.210477828979492,   0.9455615878105164,   2.9364101886749268,
769                   4.426696300506592,   7.012911796569824,    5.661986351013184,
770                   3.47651743888855,    5.215754985809326,    3.3599345684051514,
771                   -0.8365635275840759, 3.3951313495635986,   3.957016706466675,
772                   0.6408365964889526,  6.331282138824463,    1.2146613597869873,
773                   4.629482746124268,   -0.07654135674238205, 2.7050139904022217,
774                   1.3721752166748047,  0.6888798475265503,   2.5945637226104736,
775                   1.6934765577316284,  3.273261785507202,    7.484044551849365,
776                   6.798170566558838,   7.094943046569824,    5.459225177764893,
777                   7.762905597686768,   5.0990309715271,      5.052957057952881,
778                   4.175616264343262,   5.202394008636475,    3.328611373901367,
779                   6.0238728523254395,  0.8306096196174622,   6.496560573577881,
780                   2.7677316665649414,  4.183845043182373,    4.691458225250244,
781                   7.34980583190918,    3.90541672706604,     5.50336217880249,
782                   3.655266761779785,   0.37211874127388,     3.7732315063476562,
783                   3.9751780033111572,  3.0309712886810303,   -0.398685097694397,
784                   2.255678176879883,   1.2390896081924438,   8.13373851776123,
785                   2.2620372772216797,  5.771909713745117,    2.711566209793091,
786                   3.2726879119873047,  1.5897270441055298,   0.590388298034668,
787                   0.8605414032936096,  6.366031169891357,    5.391939163208008,
788                   7.207645893096924,   7.243220806121826,    7.392216205596924,
789                   2.784320116043091,   3.1940338611602783,   4.33238410949707,
790                   4.128670692443848,   2.8720436096191406,   6.899885654449463,
791                   0.716785728931427,   7.575404644012451,    6.966599464416504,
792                   0.27579912543296814, 5.7870965003967285,   4.197203159332275,
793                   7.685911178588867,   5.174814224243164,    6.511058807373047,
794                   0.5066202878952026,  -0.8779374957084656,  0.699552595615387,
795                   -1.178098201751709,  0.7820366024971008,   4.845582962036133,
796                   4.020108699798584,   4.708751201629639,    5.972416877746582,
797                   2.6356256008148193,  3.511096715927124,    0.5719462633132935,
798                   0.8551380038261414,  0.4269539415836334,   1.7168775796890259,
799                   6.421204090118408,   8.26783275604248,     5.88881254196167,
800                   7.7620062828063965,  7.422143459320068,    3.1615889072418213,
801                   2.486158609390259,   4.304216384887695,    3.6477456092834473,
802                   2.1144304275512695,  3.2460241317749023,   0.7162784337997437,
803                   8.805062294006348,   1.2371110916137695,   8.694275856018066,
804                   5.989792346954346,   4.388367176055908,    6.94879674911499,
805                   6.813542366027832,   7.285834312438965});
806   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
807   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
808   op_native_batch_norm_legit_no_training_out(
809       input,
810       weight,
811       bias,
812       running_mean,
813       running_var,
814       momentum,
815       eps,
816       out0,
817       out1,
818       out2);
819   EXPECT_TENSOR_CLOSE(out0, out0_expected);
820   EXPECT_TENSOR_CLOSE(out1, out1_expected);
821   EXPECT_TENSOR_CLOSE(out2, out2_expected);
822 }
823 
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest,SampleAtomicTestNoWeightNoBias)824 TEST_F(
825     OpNativeBatchNormLegitNoTrainingOutTest,
826     SampleAtomicTestNoWeightNoBias) {
827   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
828 
829   exec_aten::Tensor input = tfFloat.make(
830       {2, 4, 2, 2},
831       {2.628833770751953,   7.391754150390625,  9.153281211853027,
832        2.480319023132324,   6.5120697021484375, 5.680999755859375,
833        9.440492630004883,   8.139138221740723,  5.618698596954346,
834        0.21270036697387695, 8.981918334960938,  8.472748756408691,
835        2.5718064308166504,  5.815331935882568,  0.08409619331359863,
836        2.942138195037842,   1.8946051597595215, 9.46719741821289,
837        0.5490684509277344,  2.2121663093566895, 5.5882368087768555,
838        9.131031036376953,   5.822923183441162,  3.371715545654297,
839        0.1542043685913086,  3.606675863265991,  2.65787410736084,
840        5.136600494384766,   6.950716972351074,  6.051759719848633,
841        7.304986953735352,   6.186429977416992});
842   exec_aten::optional<exec_aten::Tensor> weight;
843   exec_aten::optional<exec_aten::Tensor> bias;
844   exec_aten::Tensor running_mean = tfFloat.make(
845       {4},
846       {8.043643951416016,
847        3.569627285003662,
848        7.6375412940979,
849        4.194377899169922});
850   exec_aten::Tensor running_var = tfFloat.make(
851       {4},
852       {7.512979507446289,
853        0.0478285551071167,
854        0.8684122562408447,
855        1.9676220417022705});
856   double momentum = 0.1;
857   double eps = 0;
858   exec_aten::Tensor out0 = tfFloat.zeros({2, 4, 2, 2});
859   exec_aten::Tensor out1 = tfFloat.zeros({0});
860   exec_aten::Tensor out2 = tfFloat.zeros({0});
861   exec_aten::Tensor out0_expected = tfFloat.make(
862       {2, 4, 2, 2},
863       {-1.975500464439392,  -0.23783083260059357, 0.40483206510543823,
864        -2.0296835899353027, 13.454400062561035,   9.65431022644043,
865        26.844696044921875,  20.894216537475586,   -2.1664047241210938,
866        -7.967539310455322,  1.4426401853561401,   0.896254301071167,
867        -1.1567326784133911, 1.1555795669555664,   -2.9302234649658203,
868        -0.8927228450775146, -2.2433712482452393,  0.5193589925765991,
869        -2.734267234802246,  -2.127514362335205,   9.230148315429688,
870        25.42967414855957,   10.303258895874023,   -0.9049565196037292,
871        -8.03031063079834,   -4.325490474700928,   -5.343642234802246,
872        -2.6837403774261475, 1.9649964570999146,   1.3241291046142578,
873        2.2175559997558594,  1.4201356172561646});
874   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
875   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
876   op_native_batch_norm_legit_no_training_out(
877       input,
878       weight,
879       bias,
880       running_mean,
881       running_var,
882       momentum,
883       eps,
884       out0,
885       out1,
886       out2);
887   EXPECT_TENSOR_CLOSE(out0, out0_expected);
888   EXPECT_TENSOR_CLOSE(out1, out1_expected);
889   EXPECT_TENSOR_CLOSE(out2, out2_expected);
890 }
891 
TEST_F(OpNativeBatchNormLegitOutTest,SampleAtomicTest2D)892 TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
893   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
894 
895   exec_aten::Tensor input = tfFloat.make(
896       {4, 7}, {2.876736640930176,  7.67944860458374,   5.701690196990967,
897                9.299789428710938,  3.023690700531006,  5.315116882324219,
898                7.185585021972656,  6.911304473876953,  7.61051082611084,
899                1.4963287115097046, 0.7381612062454224, 8.588483810424805,
900                6.583977699279785,  8.831110000610352,  0.8165055513381958,
901                7.087201118469238,  5.572513580322266,  4.446897983551025,
902                4.444573402404785,  6.254056930541992,  5.906398296356201,
903                9.971039772033691,  3.5423521995544434, 7.452159881591797,
904                9.93700122833252,   1.8560808897018433, 1.524025797843933,
905                7.3222975730896});
906   exec_aten::optional<exec_aten::Tensor> weight =
907       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
908           {7},
909           {8.287437438964844,
910            8.227645874023438,
911            6.65926456451416,
912            9.436124801635742,
913            4.119281768798828,
914            8.593960762023926,
915            2.3760855197906494}));
916   exec_aten::optional<exec_aten::Tensor> bias =
917       exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
918           {7},
919           {7.824275970458984,
920            6.84327507019043,
921            8.354326248168945,
922            8.773970603942871,
923            3.89609694480896,
924            3.0753469467163086,
925            3.1105971336364746}));
926   exec_aten::Tensor running_mean = tfFloat.make(
927       {7},
928       {9.700226783752441,
929        0.1234668493270874,
930        7.527220249176025,
931        8.993252754211426,
932        0.4736626148223877,
933        7.7135701179504395,
934        5.12320613861084});
935   exec_aten::Tensor running_var = tfFloat.make(
936       {7},
937       {3.585531234741211,
938        6.615292549133301,
939        0.24084866046905518,
940        5.175800323486328,
941        0.5886000394821167,
942        6.23909854888916,
943        1.5029621124267578});
944   bool training = false;
945   double momentum = 0.1;
946   double eps = 0;
947   exec_aten::Tensor out0 = tfFloat.zeros({4, 7});
948   exec_aten::Tensor out1 = tfFloat.zeros({0});
949   exec_aten::Tensor out2 = tfFloat.zeros({0});
950   exec_aten::Tensor out0_expected = tfFloat.make(
951       {4, 7}, {-22.039867401123047, 31.014127731323242,  -16.416650772094727,
952                10.04538631439209,   17.5877628326416,    -5.17673921585083,
953                7.1078033447265625,  -4.381907939910889,  30.793603897094727,
954                -73.48003387451172,  -25.46548080444336,  47.46636962890625,
955                -0.8111140131950378, 10.29708194732666,   -31.056814193725586,
956                29.119586944580078,  -18.16947364807129,  -10.082839965820312,
957                25.216796875,        -1.9462348222732544, 4.628543376922607,
958                9.00953483581543,    17.779958724975586,  7.335818767547607,
959                12.688335418701172,  11.318607330322266,  -18.22031593322754,
960                7.372773170471191});
961   exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
962   exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
963   op_native_batch_norm_legit_out(
964       input,
965       weight,
966       bias,
967       running_mean,
968       running_var,
969       training,
970       momentum,
971       eps,
972       out0,
973       out1,
974       out2);
975   EXPECT_TENSOR_CLOSE(out0, out0_expected);
976   EXPECT_TENSOR_CLOSE(out1, out1_expected);
977   EXPECT_TENSOR_CLOSE(out2, out2_expected);
978 }
979 
// Calls op_native_batch_norm_legit_no_stats_out with training == true on a
// Float {3, 4} input holding the squares 0..121, with no weight and no bias,
// momentum 1e-3 and eps 1e-5.
// The expected second output holds the mean of each column (e.g.
// (0 + 16 + 64) / 3 for column 0); the expected third output matches
// 1 / sqrt(biased variance + eps) of each column.
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  Tensor in =
      tf.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
  // Both affine parameters are left without a value.
  OptionalTensor weight;
  OptionalTensor bias;
  bool training{true};
  double momentum{1e-3};
  double eps{1e-5};

  Tensor actual = tf.zeros({3, 4});
  Tensor actual_mean = tf.zeros({4});
  Tensor actual_invstd = tf.zeros({4});

  Tensor expected = tf.make(
      {3, 4},
      {-0.98058063,
       -1.03422451,
       -1.06904495,
       -1.09332705,
       -0.39223224,
       -0.31822300,
       -0.26726127,
       -0.23017406,
       1.37281299,
       1.35244739,
       1.33630610,
       1.32350123});
  Tensor expected_mean =
      tf.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
  Tensor expected_invstd =
      tf.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});

  op_native_batch_norm_legit_no_stats_out(
      in,
      weight,
      bias,
      training,
      momentum,
      eps,
      actual,
      actual_mean,
      actual_invstd);

  EXPECT_TENSOR_CLOSE(actual, expected);
  EXPECT_TENSOR_CLOSE(actual_mean, expected_mean);
  EXPECT_TENSOR_CLOSE(actual_invstd, expected_invstd);
}
1019 
// Calls op_native_batch_norm_legit_no_stats_out with training == true on a
// Float {2, 3, 4} input holding the squares 0..529, with no weight and no
// bias, momentum 1e-3 and eps 1e-5.
// The expected second and third outputs have shape {3} (the size of input
// dimension 1). The second holds the mean over each dimension-1 slice (e.g.
// 93.5 for slice 0); the third matches 1 / sqrt(biased variance + eps) of
// the same slice.
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  Tensor in = tf.make(
      {2, 3, 4}, {0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121,
                  144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
  // Both affine parameters are left without a value.
  OptionalTensor weight;
  OptionalTensor bias;
  bool training{true};
  double momentum{1e-3};
  double eps{1e-5};

  Tensor actual = tf.zeros({2, 3, 4});
  Tensor actual_mean = tf.zeros({3});
  Tensor actual_invstd = tf.zeros({3});

  Tensor expected = tf.make(
      {2, 3, 4},
      {-1.01045656, -0.99964952, -0.96722847, -0.91319335, -1.08850884,
       -1.02468753, -0.94668359, -0.85449719, -1.12558389, -1.03595889,
       -0.93578988, -0.82507670, 0.54575467,  0.81593025,  1.10771990,
       1.42112350,  0.61339414,  0.84740579,  1.09560001,  1.35797679,
       0.64582670,  0.86198103,  1.08867943,  1.32592189});
  Tensor expected_mean = tf.make({3}, {93.5, 169.5, 277.5});
  Tensor expected_invstd =
      tf.make({3}, {0.01080702, 0.00709126, 0.00527206});

  op_native_batch_norm_legit_no_stats_out(
      in,
      weight,
      bias,
      training,
      momentum,
      eps,
      actual,
      actual_mean,
      actual_invstd);

  EXPECT_TENSOR_CLOSE(actual, expected);
  EXPECT_TENSOR_CLOSE(actual_mean, expected_mean);
  EXPECT_TENSOR_CLOSE(actual_invstd, expected_invstd);
}
1052 
// Calls op_native_batch_norm_legit_no_stats_out with training == true on a
// Float {2, 3, 2, 2} input holding the squares 0..529, with weight
// {1.1, 0.7, 0.3} and bias {1.7, 2.2, 3.3} of shape {3} (the size of input
// dimension 1), momentum 1e-3 and eps 1e-5.
// The expected second and third outputs have shape {3}. The second holds the
// mean over each dimension-1 slice (e.g. 93.5 for slice 0); the third
// matches 1 / sqrt(biased variance + eps) of the same slice.
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) {
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  Tensor in =
      tf.make({2, 3, 2, 2}, {0,   1,   4,   9,   16,  25,  36,  49,
                             64,  81,  100, 121, 144, 169, 196, 225,
                             256, 289, 324, 361, 400, 441, 484, 529});
  OptionalTensor weight(tf.make({3}, {1.1, 0.7, 0.3}));
  OptionalTensor bias(tf.make({3}, {1.7, 2.2, 3.3}));
  bool training{true};
  double momentum{1e-3};
  double eps{1e-5};

  Tensor actual = tf.zeros({2, 3, 2, 2});
  Tensor actual_mean = tf.zeros({3});
  Tensor actual_invstd = tf.zeros({3});

  Tensor expected = tf.make(
      {2, 3, 2, 2},
      {0.58849782, 0.60038555, 0.63604873, 0.69548732, 1.43804383, 1.48271883,
       1.53732157, 1.60185206, 2.96232486, 2.98921227, 3.01926303, 3.05247688,
       2.30033016, 2.59752321, 2.91849184, 3.26323581, 2.62937593, 2.79318404,
       2.96691990, 3.15058374, 3.49374819, 3.55859423, 3.62660384, 3.69777656});
  Tensor expected_mean = tf.make({3}, {93.5, 169.5, 277.5});
  Tensor expected_invstd =
      tf.make({3}, {0.01080702, 0.00709126, 0.00527206});

  op_native_batch_norm_legit_no_stats_out(
      in,
      weight,
      bias,
      training,
      momentum,
      eps,
      actual,
      actual_mean,
      actual_invstd);

  EXPECT_TENSOR_CLOSE(actual, expected);
  EXPECT_TENSOR_CLOSE(actual_mean, expected_mean);
  EXPECT_TENSOR_CLOSE(actual_invstd, expected_invstd);
}
1087