1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19
20 class OpNativeBatchNormLegitNoTrainingOutTest : public OperatorTest {
21 protected:
22 ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_no_training_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,const exec_aten::Tensor & running_mean,const exec_aten::Tensor & running_var,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)23 op_native_batch_norm_legit_no_training_out(
24 const exec_aten::Tensor& input,
25 const exec_aten::optional<exec_aten::Tensor>& weight,
26 const exec_aten::optional<exec_aten::Tensor>& bias,
27 const exec_aten::Tensor& running_mean,
28 const exec_aten::Tensor& running_var,
29 double momentum,
30 double eps,
31 exec_aten::Tensor& out0,
32 exec_aten::Tensor& out1,
33 exec_aten::Tensor& out2) {
34 return torch::executor::aten::_native_batch_norm_legit_no_training_outf(
35 context_,
36 input,
37 weight,
38 bias,
39 running_mean,
40 running_var,
41 momentum,
42 eps,
43 out0,
44 out1,
45 out2);
46 }
47 };
48
49 class OpNativeBatchNormLegitOutTest : public OperatorTest {
50 protected:
51 ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,exec_aten::Tensor & running_mean,exec_aten::Tensor & running_var,bool training,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)52 op_native_batch_norm_legit_out(
53 const exec_aten::Tensor& input,
54 const exec_aten::optional<exec_aten::Tensor>& weight,
55 const exec_aten::optional<exec_aten::Tensor>& bias,
56 exec_aten::Tensor& running_mean,
57 exec_aten::Tensor& running_var,
58 bool training,
59 double momentum,
60 double eps,
61 exec_aten::Tensor& out0,
62 exec_aten::Tensor& out1,
63 exec_aten::Tensor& out2) {
64 executorch::runtime::KernelRuntimeContext context{};
65 return torch::executor::aten::_native_batch_norm_legit_outf(
66 context,
67 input,
68 weight,
69 bias,
70 running_mean,
71 running_var,
72 training,
73 momentum,
74 eps,
75 out0,
76 out1,
77 out2);
78 }
79 };
80
81 class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
82 protected:
83 ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_no_stats_out(const exec_aten::Tensor & input,const exec_aten::optional<exec_aten::Tensor> & weight,const exec_aten::optional<exec_aten::Tensor> & bias,bool training,double momentum,double eps,exec_aten::Tensor & out0,exec_aten::Tensor & out1,exec_aten::Tensor & out2)84 op_native_batch_norm_legit_no_stats_out(
85 const exec_aten::Tensor& input,
86 const exec_aten::optional<exec_aten::Tensor>& weight,
87 const exec_aten::optional<exec_aten::Tensor>& bias,
88 bool training,
89 double momentum,
90 double eps,
91 exec_aten::Tensor& out0,
92 exec_aten::Tensor& out1,
93 exec_aten::Tensor& out2) {
94 return torch::executor::aten::_native_batch_norm_legit_outf(
95 context_,
96 input,
97 weight,
98 bias,
99 training,
100 momentum,
101 eps,
102 out0,
103 out1,
104 out2);
105 }
106 };
107
// No-training batch norm on a float input of shape [4, 7].
// weight, bias, running_mean and running_var each have length 7 (dim 1 of
// the input). eps is 0.
// Spot-checked: the expected values match
//   (x - running_mean[c]) / sqrt(running_var[c] + eps) * weight[c] + bias[c]
// with c = column index.
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  const exec_aten::Tensor input = tf.make(
      {4, 7},
      {2.876736640930176, 7.67944860458374, 5.701690196990967,
       9.299789428710938, 3.023690700531006, 5.315116882324219,
       7.185585021972656, 6.911304473876953, 7.61051082611084,
       1.4963287115097046, 0.7381612062454224, 8.588483810424805,
       6.583977699279785, 8.831110000610352, 0.8165055513381958,
       7.087201118469238, 5.572513580322266, 4.446897983551025,
       4.444573402404785, 6.254056930541992, 5.906398296356201,
       9.971039772033691, 3.5423521995544434, 7.452159881591797,
       9.93700122833252, 1.8560808897018433, 1.524025797843933,
       7.3222975730896});
  const exec_aten::optional<exec_aten::Tensor> weight(tf.make(
      {7},
      {8.287437438964844, 8.227645874023438, 6.65926456451416,
       9.436124801635742, 4.119281768798828, 8.593960762023926,
       2.3760855197906494}));
  const exec_aten::optional<exec_aten::Tensor> bias(tf.make(
      {7},
      {7.824275970458984, 6.84327507019043, 8.354326248168945,
       8.773970603942871, 3.89609694480896, 3.0753469467163086,
       3.1105971336364746}));
  const exec_aten::Tensor running_mean = tf.make(
      {7},
      {9.700226783752441, 0.1234668493270874, 7.527220249176025,
       8.993252754211426, 0.4736626148223877, 7.7135701179504395,
       5.12320613861084});
  const exec_aten::Tensor running_var = tf.make(
      {7},
      {3.585531234741211, 6.615292549133301, 0.24084866046905518,
       5.175800323486328, 0.5886000394821167, 6.23909854888916,
       1.5029621124267578});

  // out0 receives the normalized result. out1/out2 are zero-element tensors
  // and the test expects them to stay empty.
  exec_aten::Tensor out0 = tf.zeros({4, 7});
  exec_aten::Tensor out1 = tf.zeros({0});
  exec_aten::Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      /*momentum=*/0.1,
      /*eps=*/0.0,
      out0,
      out1,
      out2);

  const exec_aten::Tensor out0_expected = tf.make(
      {4, 7},
      {-22.039867401123047, 31.014127731323242, -16.416650772094727,
       10.04538631439209, 17.5877628326416, -5.17673921585083,
       7.1078033447265625, -4.381907939910889, 30.793603897094727,
       -73.48003387451172, -25.46548080444336, 47.46636962890625,
       -0.8111140131950378, 10.29708194732666, -31.056814193725586,
       29.119586944580078, -18.16947364807129, -10.082839965820312,
       25.216796875, -1.9462348222732544, 4.628543376922607,
       9.00953483581543, 17.779958724975586, 7.335818767547607,
       12.688335418701172, 11.318607330322266, -18.22031593322754,
       7.372773170471191});
  const exec_aten::Tensor out1_expected = tf.make({0}, {});
  const exec_aten::Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
193
// No-training batch norm on a float input of shape [4, 7, 5].
// weight, bias, running_mean and running_var each have length 7 (dim 1 of
// the input). eps is 0. out0 is compared against hard-coded expected values.
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest3D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  const exec_aten::Tensor input = tf.make(
      {4, 7, 5},
      {5.277339935302734, 5.94276237487793, 6.543086051940918,
       2.411264181137085, 8.980886459350586, 2.7123653888702393,
       9.466896057128906, 9.324702262878418, 1.9848430156707764,
       8.388091087341309, 1.5069717168807983, 5.350819110870361,
       1.727534532546997, 4.913003444671631, 2.555372714996338,
       4.321412563323975, 1.107364296913147, 6.048641681671143,
       9.496582984924316, 0.9668296575546265, 0.8103430271148682,
       8.187652587890625, 9.455179214477539, 0.5739009380340576,
       3.550161838531494, 1.5362483263015747, 7.338945388793945,
       3.583885431289673, 6.5086517333984375, 0.9027481079101562,
       0.8805221319198608, 3.983092784881592, 5.43976354598999,
       9.080245971679688, 2.602390766143799, 2.1537625789642334,
       3.2551045417785645, 7.098634719848633, 8.135055541992188,
       7.457048416137695, 5.3438568115234375, 3.7822632789611816,
       3.4284191131591797, 6.144853115081787, 9.79615592956543,
       5.735219955444336, 2.5468051433563232, 8.514262199401855,
       3.775507926940918, 8.327726364135742, 4.772212505340576,
       7.100861072540283, 3.477612018585205, 9.359293937683105,
       5.203947067260742, 3.6150975227355957, 6.159048557281494,
       0.9919929504394531, 1.6809028387069702, 0.3627735376358032,
       1.8791186809539795, 4.037001132965088, 8.129783630371094,
       4.79802131652832, 2.9911656379699707, 8.659820556640625,
       7.378345489501953, 3.6833512783050537, 2.4555420875549316,
       8.481515884399414, 3.733121156692505, 6.075705528259277,
       6.900073051452637, 6.380939960479736, 3.204977512359619,
       2.058135986328125, 4.60728120803833, 7.737727165222168,
       5.3178815841674805, 9.224492073059082, 4.838874340057373,
       2.717348337173462, 1.8555694818496704, 1.856197714805603,
       7.189084053039551, 5.280246257781982, 7.550882816314697,
       0.6145977973937988, 6.764681816101074, 4.217874526977539,
       0.89302659034729, 2.4634499549865723, 3.51415753364563,
       5.038887977600098, 4.948186874389648, 8.326996803283691,
       8.919670104980469, 4.45585298538208, 0.5209791660308838,
       4.2513017654418945, 0.047875046730041504, 2.453791618347168,
       6.113187789916992, 5.47722053527832, 7.524778842926025,
       0.3724473714828491, 2.6570069789886475, 9.420238494873047,
       4.650344371795654, 4.206380844116211, 1.2107867002487183,
       3.3689606189727783, 4.082674980163574, 5.31553840637207,
       4.759864807128906, 5.461820602416992, 2.0690488815307617,
       9.234517097473145, 1.6740238666534424, 3.492245674133301,
       9.844581604003906, 4.278226852416992, 2.9611783027648926,
       9.626322746276855, 7.756594657897949, 3.4873299598693848,
       6.345180988311768, 5.55388069152832, 8.535417556762695,
       8.509242057800293, 8.684778213500977, 3.784114122390747,
       3.887125253677368, 9.278786659240723, 6.742891311645508,
       5.01821756362915, 2.326876640319824, 7.939553737640381,
       3.2622408866882324, 3.829448699951172});
  const exec_aten::optional<exec_aten::Tensor> weight(tf.make(
      {7},
      {0.5193436145782471, 4.531304836273193, 8.960723876953125,
       8.598731994628906, 2.6848177909851074, 7.309220314025879,
       2.2476916313171387}));
  const exec_aten::optional<exec_aten::Tensor> bias(tf.make(
      {7},
      {4.643010139465332, 0.2791440486907959, 3.6721653938293457,
       3.918765068054199, 2.6499342918395996, 5.721188545227051,
       5.901060104370117}));
  const exec_aten::Tensor running_mean = tf.make(
      {7},
      {5.818909645080566, 5.325511932373047, 7.094021797180176,
       4.9185566902160645, 5.608961582183838, 3.7719011306762695,
       6.7734270095825195});
  const exec_aten::Tensor running_var = tf.make(
      {7},
      {8.8593168258667, 3.440363883972168, 7.105681896209717,
       1.0423260927200317, 6.756608009338379, 4.527579307556152,
       2.022289752960205});

  // out0 receives the normalized result. out1/out2 are zero-element tensors
  // and the test expects them to stay empty.
  exec_aten::Tensor out0 = tf.zeros({4, 7, 5});
  exec_aten::Tensor out1 = tf.zeros({0});
  exec_aten::Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      /*momentum=*/0.1,
      /*eps=*/0.0,
      out0,
      out1,
      out2);

  const exec_aten::Tensor out0_expected = tf.make(
      {4, 7, 5},
      {4.5485148429870605, 4.664620399475098, 4.76936674118042,
       4.048431873321533, 5.194723129272461, -6.104737281799316,
       10.396490097045898, 10.049112319946289, -7.8820648193359375,
       7.760983943939209, -15.109009742736816, -2.1877059936523438,
       -14.367575645446777, -3.659447431564331, -11.584752082824707,
       -1.1105821132659912, -28.180377960205078, 13.436722755432129,
       42.476444244384766, -29.3640079498291, -2.306469440460205,
       5.313416481018066, 6.622621059417725, -2.5506861209869385,
       0.5234383940696716, -1.9584782123565674, 17.97430419921875,
       5.075337886810303, 15.122170448303223, -4.134607791900635,
       -3.413116931915283, 1.4907281398773193, 3.793105363845825,
       9.547160148620605, -0.69157475233078, 4.003501892089844,
       4.1956682205200195, 4.8663010597229, 5.047139644622803,
       4.92883825302124, 0.3239605128765106, -3.4909913539886475,
       -4.3554277420043945, 2.2807836532592773, 11.20086669921875,
       -0.8955214619636536, -11.613553047180176, 8.446381568908691,
       -7.483201026916504, 7.819331169128418, 2.686206579208374,
       22.29886817932129, -8.217354774475098, 41.320152282714844,
       6.322420597076416, 0.5905092358589172, 3.218108892440796,
       -2.1188466548919678, -1.4072843790054321, -2.7687556743621826,
       -0.7806879878044128, 6.63183069229126, 20.690902709960938,
       9.246002197265625, 3.039292335510254, 8.882646560668945,
       6.857179164886475, 1.016964316368103, -0.9236800670623779,
       8.600822448730469, 4.279074192047119, 4.687816619873047,
       4.831655502319336, 4.741075038909912, 4.1869215965271,
       -7.7030110359191895, -1.475483775138855, 6.172153472900391,
       0.2605033814907074, 9.804300308227539, -3.9086363315582275,
       -11.040262222290039, -13.937179565429688, -13.935067176818848,
       3.991722345352173, 6.965037822723389, 26.08910369873047,
       -32.330623626708984, 19.467453002929688, -1.9826143980026245,
       -2.221067190170288, -0.5990060567855835, 0.48625022172927856,
       2.0611159801483154, 1.9674323797225952, 21.36834716796875,
       23.404233932495117, 8.070624351501465, -5.446018218994141,
       7.367972373962402, -4.729177951812744, -0.9264468550682068,
       4.8575029373168945, 3.852308988571167, 7.08862829208374,
       3.6926915645599365, 4.091310024261475, 5.271382808685303,
       4.439114570617676, 4.361649990081787, -9.77307415008545,
       -4.5006842613220215, -2.757089614868164, 0.25477901101112366,
       -1.1027240753173828, -1.8145684003829956, -13.21955680847168,
       10.867557525634766, -14.547454833984375, -8.435402870178223,
       45.407405853271484, -1.4743067026138306, -12.566932678222656,
       43.569156646728516, 27.821678161621094, 0.45854052901268005,
       3.4103615283966064, 2.5930423736572266, 5.672616004943848,
       5.645579814910889, 22.59735870361328, 5.76314115524292,
       6.116993427276611, 24.63783073425293, 15.926804542541504,
       3.1268203258514404, -1.1270453929901123, 7.744210720062256,
       0.3513677716255188, 1.2478822469711304});
  const exec_aten::Tensor out1_expected = tf.make({0}, {});
  const exec_aten::Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
353
// No-training batch norm on a float input of shape [2, 4, 5, 5].
// weight, bias, running_mean and running_var each have length 4 (dim 1 of
// the input). eps is 0. out0 is compared against hard-coded expected values.
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest4D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  const exec_aten::Tensor input = tf.make(
      {2, 4, 5, 5},
      {8.0573148727417, 2.2901253700256348, 2.783101797103882,
       2.095468044281006, 6.389344215393066, 6.702191352844238,
       6.535638809204102, 1.8584740161895752, 6.037202835083008,
       7.588045120239258, 0.7384824752807617, 7.876931190490723,
       2.198972225189209, 0.8259981870651245, 8.311962127685547,
       8.748727798461914, 6.331905841827393, 1.6120970249176025,
       7.9793596267700195, 9.730956077575684, 8.96406078338623,
       7.34755802154541, 1.0760420560836792, 6.761768341064453,
       3.18643856048584, 0.32129645347595215, 5.146165370941162,
       1.9008630514144897, 3.1616015434265137, 3.077312707901001,
       7.684902667999268, 3.1405091285705566, 6.699800491333008,
       0.7976526021957397, 1.5945738554000854, 0.7354140281677246,
       9.370306015014648, 4.1550726890563965, 4.169681549072266,
       5.389268398284912, 6.883472442626953, 8.881608963012695,
       7.600193500518799, 8.894989967346191, 1.7032986879348755,
       8.945396423339844, 1.6370415687561035, 7.708703994750977,
       7.488667964935303, 7.315606594085693, 5.349757194519043,
       6.913224220275879, 3.6051642894744873, 3.8086843490600586,
       3.2311654090881348, 4.91132926940918, 1.331128478050232,
       2.73335337638855, 0.46345293521881104, 8.168035507202148,
       8.112630844116211, 9.38737678527832, 8.532957077026367,
       8.641634941101074, 7.772867679595947, 3.7504279613494873,
       1.1857783794403076, 7.61868953704834, 9.75157642364502,
       3.6754441261291504, 2.468808174133301, 6.380059719085693,
       6.197269439697266, 7.659857273101807, 6.72884464263916,
       9.320260047912598, 1.9144713878631592, 6.228992462158203,
       2.7658307552337646, 6.0448317527771, 1.1033517122268677,
       7.482324600219727, 4.140635013580322, 0.4461771249771118,
       9.729606628417969, 7.259793758392334, 7.154001235961914,
       8.320201873779297, 0.8773839473724365, 6.855964660644531,
       4.737044334411621, 4.0600152015686035, 6.474225044250488,
       0.8523398637771606, 3.7826621532440186, 5.399431228637695,
       0.17764925956726074, 5.480880260467529, 1.5790224075317383,
       7.965246200561523, 0.919603705406189, 6.623161315917969,
       6.618031978607178, 1.6051316261291504, 0.07815778255462646,
       7.8453497886657715, 2.781987190246582, 0.28109610080718994,
       9.149931907653809, 7.448637962341309, 5.52522087097168,
       4.095173358917236, 6.3080902099609375, 5.314402103424072,
       8.845094680786133, 6.3725972175598145, 1.9547373056411743,
       5.2839508056640625, 3.5294246673583984, 3.570653200149536,
       2.5026822090148926, 0.5656778812408447, 8.309356689453125,
       0.7813519239425659, 2.366170883178711, 9.322799682617188,
       0.5455368757247925, 0.7133877277374268, 6.577077388763428,
       8.393207550048828, 5.753355979919434, 7.874646186828613,
       6.351865768432617, 7.233908176422119, 7.866637706756592,
       5.024176120758057, 5.872377872467041, 0.3430730104446411,
       1.7413997650146484, 7.130331993103027, 7.7794294357299805,
       8.817843437194824, 4.551261901855469, 4.685880661010742,
       0.4518568515777588, 3.2571589946746826, 9.467324256896973,
       6.947274208068848, 1.1890357732772827, 4.438136100769043,
       0.790744423866272, 0.9745275974273682, 2.3840129375457764,
       9.280584335327148, 7.309266090393066, 6.359057903289795,
       4.779758930206299, 6.523046970367432, 2.581796169281006,
       7.4173126220703125, 5.556275844573975, 6.3515143394470215,
       9.909261703491211, 4.264077663421631, 1.5390598773956299,
       6.409996032714844, 9.431000709533691, 6.966275215148926,
       6.593939781188965, 9.72049331665039, 8.224472045898438,
       1.1502748727798462, 9.417522430419922, 2.0071351528167725,
       7.99619722366333, 5.217411518096924, 0.5482637882232666,
       3.6407017707824707, 9.56554889678955, 5.932462215423584,
       8.26833724975586, 2.5603179931640625, 7.974213600158691,
       6.683809280395508, 5.0010175704956055, 8.93687915802002,
       4.7291178703308105, 1.1585253477096558, 2.50417423248291,
       3.685148239135742, 0.36632418632507324, 7.834067344665527,
       9.173870086669922, 3.781676769256592, 5.6734232902526855,
       3.301741600036621, 1.3799077272415161, 8.990988731384277,
       2.2520315647125244, 2.483280897140503});
  const exec_aten::optional<exec_aten::Tensor> weight(tf.make(
      {4},
      {1.8311285972595215, 5.851841926574707, 6.108979225158691,
       5.1755266189575195}));
  const exec_aten::optional<exec_aten::Tensor> bias(tf.make(
      {4},
      {5.1375732421875, 3.7950849533081055, 2.406358242034912,
       5.785604476928711}));
  const exec_aten::Tensor running_mean = tf.make(
      {4},
      {2.8203158378601074, 3.1786017417907715, 1.9189423322677612,
       1.8829244375228882});
  const exec_aten::Tensor running_var = tf.make(
      {4},
      {1.4411485195159912, 7.426868438720703, 7.584629535675049,
       5.526189804077148});

  // out0 receives the normalized result. out1/out2 are zero-element tensors
  // and the test expects them to stay empty.
  exec_aten::Tensor out0 = tf.zeros({2, 4, 5, 5});
  exec_aten::Tensor out1 = tf.zeros({0});
  exec_aten::Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      /*momentum=*/0.1,
      /*eps=*/0.0,
      out0,
      out1,
      out2);

  const exec_aten::Tensor out0_expected = tf.make(
      {2, 4, 5, 5},
      {13.125737190246582, 4.328856468200684, 5.080809593200684,
       4.031939506530762, 10.581527709960938, 11.058723449707031,
       10.804675102233887, 3.6704447269439697, 10.044395446777344,
       12.409944534301758, 1.962085485458374, 12.850592613220215,
       4.189817905426025, 2.095576047897339, 13.514159202575684,
       14.180371284484863, 10.493914604187012, 3.29463791847229,
       13.006829261779785, 15.678596496582031, 14.50882625579834,
       12.043122291564941, 2.476976156234741, 11.149598121643066,
       5.6960320472717285, -2.340364456176758, 8.020005226135254,
       1.0514155626296997, 3.7585806846618652, 3.5775885581970215,
       13.47139835357666, 3.713289260864258, 11.35610294342041,
       -1.3174920082092285, 0.3937252461910248, -1.4511359930038452,
       17.09044075012207, 5.891846656799316, 5.923215866088867,
       8.542016983032227, 11.75049877166748, 16.041067123413086,
       13.28950309753418, 16.069801330566406, 0.6271884441375732,
       16.178037643432617, 0.48491552472114563, 13.522506713867188,
       13.050026893615723, 12.678414344787598, 10.01660442352295,
       13.48469352722168, 6.146742343902588, 6.598191261291504,
       5.317136287689209, 9.044082641601562, 1.1024672985076904,
       4.212887763977051, -0.8222138285636902, 16.26811981201172,
       16.145221710205078, 18.972867965698242, 17.077590942382812,
       17.318660736083984, 15.391557693481445, 6.468966484069824,
       0.7800511717796326, 15.049558639526367, 19.780736923217773,
       6.302637100219727, 3.626072645187378, 12.30202579498291,
       11.896559715270996, 15.140877723693848, 13.075701713562012,
       22.15976333618164, 5.855058670043945, 15.353979110717773,
       7.729425430297852, 14.948527336120605, 4.069284439086914,
       18.11333465576172, 10.756217002868652, 2.6224381923675537,
       23.06098747253418, 17.6234073638916, 17.390493392944336,
       19.958019256591797, 3.5717902183532715, 16.734331130981445,
       12.069281578063965, 10.578722953796387, 15.89388656616211,
       3.516652822494507, 9.968097686767578, 13.527603149414062,
       2.031242847442627, 13.70692253112793, 5.1165289878845215,
       19.176542282104492, 2.2383556365966797, 10.938176155090332,
       10.930352210998535, 3.284013509750366, 0.9548709392547607,
       12.802419662475586, 5.079109191894531, 1.2644193172454834,
       14.792341232299805, 12.19730281829834, 9.263452529907227,
       7.082154750823975, 10.457588195800781, 8.94188404083252,
       14.327363014221191, 10.55598258972168, 3.8172783851623535,
       8.895435333251953, 6.2191996574401855, 6.282087326049805,
       4.653076171875, 1.6985011100769043, 13.510184288024902,
       2.027475595474243, 4.444851398468018, 16.98843002319336,
       -1.8588563203811646, -1.4984327554702759, 11.092581748962402,
       14.992330551147461, 9.323816299438477, 13.87883186340332,
       10.608987808227539, 12.502984046936035, 13.861635208129883,
       7.758059501647949, 9.579390525817871, -2.2936041355133057,
       0.7090023756027222, 12.280576705932617, 13.6743745803833,
       15.904145240783691, 6.742578029632568, 7.031642436981201,
       -2.060014009475708, 3.9637696743011475, 17.298765182495117,
       11.887499809265137, -0.47708070278167725, 6.499664306640625,
       -0.09621463716030121, 0.3114539086818695, 3.4379796981811523,
       18.735980987548828, 14.363194465637207, 12.255439758300781,
       8.752232551574707, 12.619200706481934, 3.8767030239105225,
       14.602864265441895, 10.47470474243164, 12.238706588745117,
       20.13051414489746, 7.608346462249756, 1.5637015104293823,
       12.368430137634277, 19.06963539123535, 13.602371215820312,
       12.77645492553711, 19.711788177490234, 16.393308639526367,
       0.7012971639633179, 19.039737701416016, 2.601987838745117,
       15.886947631835938, 13.12686538696289, 2.847193956375122,
       9.655555725097656, 22.69979476928711, 14.701132774353027,
       19.843833923339844, 7.276965141296387, 19.196285247802734,
       16.355310440063477, 12.6504487991333, 21.315706253051758,
       12.051830291748047, 4.190755844116211, 7.153357982635498,
       9.753409385681152, 2.4466326236724854, 18.887737274169922,
       21.83746910095215, 9.96592903137207, 14.130828857421875,
       8.909295082092285, 4.678154945373535, 21.43483543395996,
       6.598236560821533, 7.107358932495117});
  const exec_aten::Tensor out1_expected = tf.make({0}, {});
  const exec_aten::Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
543
// No-training batch norm on a double-precision input of shape [3, 4, 3, 3].
// weight, bias, running_mean and running_var each have length 4 (dim 1 of
// the input). eps is 0. out0 is compared against hard-coded expected values.
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestDouble) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Double> tf;

  const exec_aten::Tensor input = tf.make(
      {3, 4, 3, 3},
      {0.09871780872344971, 5.7593607902526855, 4.542290687561035,
       9.888419151306152, 4.6276702880859375, 0.23040294647216797,
       5.160412311553955, 5.192661285400391, 7.774633407592773,
       3.82037353515625, 6.421841621398926, 1.3372838497161865,
       5.101180553436279, 3.166962146759033, 0.253373384475708,
       5.272202491760254, 0.8737403154373169, 9.0341796875,
       4.930244445800781, 5.145639896392822, 8.51688003540039,
       1.0039496421813965, 6.3629584312438965, 8.20095157623291,
       7.129164695739746, 6.775269031524658, 5.83862829208374,
       4.415182590484619, 9.107303619384766, 1.1548930406570435,
       4.394702434539795, 7.173308372497559, 1.648862361907959,
       3.040163516998291, 8.946229934692383, 8.740336418151855,
       7.152044773101807, 6.766063690185547, 8.682901382446289,
       2.8464317321777344, 8.757857322692871, 8.097877502441406,
       5.039367198944092, 1.713152527809143, 1.5446704626083374,
       7.220646858215332, 5.2453131675720215, 7.095609188079834,
       6.792170524597168, 5.975555896759033, 3.7161855697631836,
       2.0132927894592285, 3.0089516639709473, 1.4530837535858154,
       2.124783515930176, 1.3747084140777588, 0.4398918151855469,
       5.140370845794678, 0.16295194625854492, 5.689471244812012,
       9.149665832519531, 9.32123851776123, 1.5971916913986206,
       3.5363614559173584, 0.4872584342956543, 7.255306243896484,
       8.349767684936523, 0.977746844291687, 0.010267496109008789,
       9.964345932006836, 9.955519676208496, 2.3190832138061523,
       9.237786293029785, 4.200929641723633, 9.231035232543945,
       3.777331829071045, 4.507022857666016, 9.332846641540527,
       0.8198702335357666, 0.8076483011245728, 6.062283992767334,
       5.735506057739258, 6.782886505126953, 6.669310569763184,
       5.708680152893066, 7.5679931640625, 4.829475402832031,
       1.1562585830688477, 0.5352389812469482, 4.793148040771484,
       1.7251378297805786, 9.661691665649414, 7.695187568664551,
       2.569558620452881, 5.02672004699707, 4.213432312011719,
       0.4719752073287964, 3.2524518966674805, 4.827580451965332,
       1.7936384677886963, 1.8733304738998413, 9.386192321777344,
       2.442445755004883, 2.2374587059020996, 1.6268903017044067,
       1.9272565841674805, 0.04978537559509277, 5.165012359619141});
  const exec_aten::optional<exec_aten::Tensor> weight(tf.make(
      {4},
      {5.4100823402404785, 3.3440847396850586, 0.9714162349700928,
       0.6811875104904175}));
  const exec_aten::optional<exec_aten::Tensor> bias(tf.make(
      {4},
      {6.839208126068115, 6.471728801727295, 3.077871799468994,
       4.0067667961120605}));
  const exec_aten::Tensor running_mean = tf.make(
      {4},
      {8.781468391418457, 5.093882083892822, 9.076446533203125,
       7.148240089416504});
  const exec_aten::Tensor running_var = tf.make(
      {4},
      {1.0133814811706543, 2.674386978149414, 6.866252422332764,
       9.597100257873535});

  // out0 receives the normalized result. out1/out2 are zero-element tensors
  // and the test expects them to stay empty.
  exec_aten::Tensor out0 = tf.zeros({3, 4, 3, 3});
  exec_aten::Tensor out1 = tf.zeros({0});
  exec_aten::Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      /*momentum=*/0.1,
      /*eps=*/0.0,
      out0,
      out1,
      out2);

  const exec_aten::Tensor out0_expected = tf.make(
      {3, 4, 3, 3},
      {-39.82401348817106, -9.402336001242755, -15.94316789328793,
       12.788231783114975, -15.48431707375971, -39.11630540562901,
       -12.621231365199568, -12.447917505830254, 1.4282310938746887,
       3.8675726975224554, 9.187229957306027, -1.2100164305929255,
       6.486653204105122, 2.53143305610451, -3.4264695963512772,
       6.836370389027959, -2.1579014884157313, 14.529114884449125,
       1.540793678645762, 1.6206449805000125, 2.870429566122108,
       0.08523948439124736, 2.07192874490631, 2.7533087138107413,
       2.3559763770206743, 2.2247803400083974, 1.8775493181436829,
       3.4058069855539905, 4.437536528655581, 2.688916473373198,
       3.401303695505533, 4.012278948950884, 2.797533181889745,
       3.1034601808347273, 4.402118755309629, 4.356845749255535,
       -1.9177122392967219, -3.992068820138994, 6.30948495370292,
       -25.057127980919624, 6.712316477851606, 3.165423782960796,
       -13.271757354773879, -31.14764712696726, -32.05311088198505,
       10.820680738292133, 6.781385286813924, 10.564995283913401,
       9.94450345561426, 8.274634831663992, 3.654522124495513,
       0.1723322067015646, 2.2083225723239543, -0.9732209832222724,
       0.5007544998438981, 0.2226870048863787, -0.12386777243752074,
       1.6186916404984417, -0.22653479260151907, 1.822253886542661,
       3.10501562425824, 3.1686209708035005, 0.3051659025884432,
       3.212566930139726, 2.542113280715248, 4.0303090947788895,
       4.270965334366817, 2.6499645871979483, 2.437229873043587,
       4.625987736153007, 4.624046970174033, 2.9449050525314044,
       9.29157194403522, -17.777725500427557, 9.25529009664348,
       -20.05424357138422, -16.132705822147077, 9.802449466893506,
       -35.948364280449816, -36.01404792933805, -7.774352749114295,
       7.783764743499358, 9.92551886693686, 9.69327114024373,
       7.728909325429857, 11.53095787270343, 5.931052201511217,
       -1.5801890954114401, -2.850091828706626, 5.856767563411625,
       0.3525980358528258, 3.2948336043782858, 2.565812114767675,
       0.6656413209484225, 1.5765590689152436, 1.2750574158293795,
       -0.11197383213610211, 0.9188032005673643, 1.5027341303183273,
       2.829367353398184, 2.8468904728035618, 4.498860120209044,
       2.972030690911762, 2.9269570039352497, 2.792701843675069,
       2.858748044414565, 2.4459192831196264, 3.570683705559329});
  const exec_aten::Tensor out1_expected = tf.make({0}, {});
  const exec_aten::Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
672
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestNoWeight) {
  // Inference-mode batch norm on a {4, 7, 5} Float input with channels along
  // dim 1. No weight is supplied, only a per-channel bias, so each element of
  // out0_expected corresponds to
  //   (x - running_mean[c]) / sqrt(running_var[c] + eps) + bias[c].
  // The two auxiliary outputs are zero-element tensors and must stay empty.
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs.
  Tensor input = tf.make(
      {4, 7, 5},
      {4.1944355964660645, 3.537543296813965, 5.067144393920898,
       9.735533714294434, 2.661299228668213, 0.43786585330963135,
       8.926244735717773, 8.796754837036133, 2.2966713905334473,
       7.153128623962402, 7.055768013000488, 0.3383845090866089,
       0.8306580781936646, 2.355782985687256, 5.922069072723389,
       1.9597464799880981, 2.731785774230957, 3.488309383392334,
       9.926213264465332, 4.582781791687012, 1.2061834335327148,
       9.317821502685547, 2.9511327743530273, 1.7717409133911133,
       6.329389572143555, 0.844573974609375, 4.269064903259277,
       3.9711995124816895, 0.7241052389144897, 2.239838123321533,
       2.2850823402404785, 8.232909202575684, 5.126026153564453,
       0.09984314441680908, 4.0997748374938965, 8.717041969299316,
       2.4102187156677246, 8.769938468933105, 9.614383697509766,
       4.630570411682129, 7.450488090515137, 2.7233500480651855,
       5.878231525421143, 1.5304350852966309, 4.100255489349365,
       3.448119640350342, 1.356201171875, 7.190479278564453,
       4.431788444519043, 9.268322944641113, 7.564930438995361,
       5.517428398132324, 6.40336799621582, 1.5203499794006348,
       8.397398948669434, 9.415580749511719, 9.271242141723633,
       6.522747993469238, 9.739391326904297, 3.8692879676818848,
       4.59047794342041, 0.6365865468978882, 4.950358867645264,
       2.111414670944214, 3.189572811126709, 2.893986701965332,
       9.007704734802246, 1.0862338542938232, 4.761219024658203,
       0.5109339952468872, 4.226720333099365, 9.338176727294922,
       9.641677856445312, 8.222650527954102, 3.068296432495117,
       3.6851234436035156, 2.7459187507629395, 9.115739822387695,
       3.6909985542297363, 6.9336957931518555, 7.548684597015381,
       9.266566276550293, 4.114157676696777, 1.0546678304672241,
       1.881745457649231, 4.227387428283691, 1.3194853067398071,
       6.739812850952148, 6.846013069152832, 7.290800094604492,
       2.164156436920166, 3.4476895332336426, 7.013863563537598,
       6.375678062438965, 2.4389731884002686, 5.257430553436279,
       0.5499267578125, 5.771737098693848, 5.308223247528076,
       0.2141815423965454, 5.413756370544434, 1.757289171218872,
       9.780686378479004, 4.005618095397949, 7.078739166259766,
       4.428859710693359, 2.348038673400879, 4.718813419342041,
       1.896933913230896, 4.842776775360107, 6.077881813049316,
       5.315243721008301, 5.951466083526611, 7.1189398765563965,
       4.036149024963379, 9.996458053588867, 0.9982073307037354,
       1.865202784538269, 0.5543112754821777, 4.5034308433532715,
       4.392091751098633, 9.904728889465332, 2.8027725219726562,
       8.39471435546875, 7.3801398277282715, 3.346047878265381,
       1.2300896644592285, 6.925620079040527, 4.869058132171631,
       0.06555616855621338, 2.475562572479248, 0.5495405197143555,
       6.707937240600586, 0.946076512336731, 6.623589515686035,
       5.87992000579834, 2.196932315826416, 8.085456848144531,
       7.774395942687988, 8.86058235168457});
  OptionalTensor weight{};
  OptionalTensor bias(tf.make(
      {7},
      {3.2798612117767334, 7.070205211639404, 0.8457618951797485,
       8.21817684173584, 4.158933162689209, 9.13807201385498,
       5.7105536460876465}));
  Tensor running_mean = tf.make(
      {7},
      {8.596701622009277, 8.133163452148438, 1.8364977836608887,
       9.756494522094727, 6.470483779907227, 6.9614739418029785,
       5.237721920013428});
  Tensor running_var = tf.make(
      {7},
      {2.258641242980957, 0.8535522222518921, 9.372869491577148,
       8.911684036254883, 9.814156532287598, 0.5796539783477783,
       5.289167881011963});
  const double momentum = 0.1;
  const double eps = 0;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({4, 7, 5});
  Tensor out1 = tf.zeros({0});
  Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input, weight, bias, running_mean, running_var, momentum, eps, out0,
      out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {4, 7, 5},
      {0.3506367802619934, -0.08645286411046982, 0.9313285946846008,
       4.037628650665283, -0.669497013092041, -1.259130597114563,
       7.928630828857422, 7.788471698760986, 0.7528274059295654,
       6.009422302246094, 2.5505621433258057, 0.3564245402812958,
       0.5172187089920044, 1.0153789520263672, 2.180255651473999,
       5.606414794921875, 5.865033149719238, 6.1184539794921875,
       8.275029182434082, 6.485081672668457, 2.478527545928955,
       5.067825794219971, 3.0355288982391357, 2.659057855606079,
       4.113894939422607, 1.1037919521331787, 5.601710796356201,
       5.210477828979492, 0.9455615878105164, 2.9364101886749268,
       4.426696300506592, 7.012911796569824, 5.661986351013184,
       3.47651743888855, 5.215754985809326, 3.3599345684051514,
       -0.8365635275840759, 3.3951313495635986, 3.957016706466675,
       0.6408365964889526, 6.331282138824463, 1.2146613597869873,
       4.629482746124268, -0.07654135674238205, 2.7050139904022217,
       1.3721752166748047, 0.6888798475265503, 2.5945637226104736,
       1.6934765577316284, 3.273261785507202, 7.484044551849365,
       6.798170566558838, 7.094943046569824, 5.459225177764893,
       7.762905597686768, 5.0990309715271, 5.052957057952881,
       4.175616264343262, 5.202394008636475, 3.328611373901367,
       6.0238728523254395, 0.8306096196174622, 6.496560573577881,
       2.7677316665649414, 4.183845043182373, 4.691458225250244,
       7.34980583190918, 3.90541672706604, 5.50336217880249,
       3.655266761779785, 0.37211874127388, 3.7732315063476562,
       3.9751780033111572, 3.0309712886810303, -0.398685097694397,
       2.255678176879883, 1.2390896081924438, 8.13373851776123,
       2.2620372772216797, 5.771909713745117, 2.711566209793091,
       3.2726879119873047, 1.5897270441055298, 0.590388298034668,
       0.8605414032936096, 6.366031169891357, 5.391939163208008,
       7.207645893096924, 7.243220806121826, 7.392216205596924,
       2.784320116043091, 3.1940338611602783, 4.33238410949707,
       4.128670692443848, 2.8720436096191406, 6.899885654449463,
       0.716785728931427, 7.575404644012451, 6.966599464416504,
       0.27579912543296814, 5.7870965003967285, 4.197203159332275,
       7.685911178588867, 5.174814224243164, 6.511058807373047,
       0.5066202878952026, -0.8779374957084656, 0.699552595615387,
       -1.178098201751709, 0.7820366024971008, 4.845582962036133,
       4.020108699798584, 4.708751201629639, 5.972416877746582,
       2.6356256008148193, 3.511096715927124, 0.5719462633132935,
       0.8551380038261414, 0.4269539415836334, 1.7168775796890259,
       6.421204090118408, 8.26783275604248, 5.88881254196167,
       7.7620062828063965, 7.422143459320068, 3.1615889072418213,
       2.486158609390259, 4.304216384887695, 3.6477456092834473,
       2.1144304275512695, 3.2460241317749023, 0.7162784337997437,
       8.805062294006348, 1.2371110916137695, 8.694275856018066,
       5.989792346954346, 4.388367176055908, 6.94879674911499,
       6.813542366027832, 7.285834312438965});
  Tensor out1_expected = tf.make({0}, {});
  Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
823
TEST_F(
    OpNativeBatchNormLegitNoTrainingOutTest,
    SampleAtomicTestNoWeightNoBias) {
  // Inference-mode batch norm on a {2, 4, 2, 2} Float input (channels along
  // dim 1) with neither weight nor bias. Each element of out0_expected
  // corresponds to (x - running_mean[c]) / sqrt(running_var[c] + eps).
  // The auxiliary outputs are zero-element tensors and must stay empty.
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs.
  Tensor input = tf.make(
      {2, 4, 2, 2},
      {2.628833770751953, 7.391754150390625, 9.153281211853027,
       2.480319023132324, 6.5120697021484375, 5.680999755859375,
       9.440492630004883, 8.139138221740723, 5.618698596954346,
       0.21270036697387695, 8.981918334960938, 8.472748756408691,
       2.5718064308166504, 5.815331935882568, 0.08409619331359863,
       2.942138195037842, 1.8946051597595215, 9.46719741821289,
       0.5490684509277344, 2.2121663093566895, 5.5882368087768555,
       9.131031036376953, 5.822923183441162, 3.371715545654297,
       0.1542043685913086, 3.606675863265991, 2.65787410736084,
       5.136600494384766, 6.950716972351074, 6.051759719848633,
       7.304986953735352, 6.186429977416992});
  OptionalTensor weight{};
  OptionalTensor bias{};
  Tensor running_mean = tf.make(
      {4},
      {8.043643951416016, 3.569627285003662, 7.6375412940979,
       4.194377899169922});
  Tensor running_var = tf.make(
      {4},
      {7.512979507446289, 0.0478285551071167, 0.8684122562408447,
       1.9676220417022705});
  const double momentum = 0.1;
  const double eps = 0;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({2, 4, 2, 2});
  Tensor out1 = tf.zeros({0});
  Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_no_training_out(
      input, weight, bias, running_mean, running_var, momentum, eps, out0,
      out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {2, 4, 2, 2},
      {-1.975500464439392, -0.23783083260059357, 0.40483206510543823,
       -2.0296835899353027, 13.454400062561035, 9.65431022644043,
       26.844696044921875, 20.894216537475586, -2.1664047241210938,
       -7.967539310455322, 1.4426401853561401, 0.896254301071167,
       -1.1567326784133911, 1.1555795669555664, -2.9302234649658203,
       -0.8927228450775146, -2.2433712482452393, 0.5193589925765991,
       -2.734267234802246, -2.127514362335205, 9.230148315429688,
       25.42967414855957, 10.303258895874023, -0.9049565196037292,
       -8.03031063079834, -4.325490474700928, -5.343642234802246,
       -2.6837403774261475, 1.9649964570999146, 1.3241291046142578,
       2.2175559997558594, 1.4201356172561646});
  Tensor out1_expected = tf.make({0}, {});
  Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
891
TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
  // _native_batch_norm_legit.out with training == false on a {4, 7} Float
  // input (channels along dim 1). The supplied running statistics are used,
  // so each element of out0_expected corresponds to
  //   (x - running_mean[c]) / sqrt(running_var[c] + eps) * weight[c] + bias[c].
  // The auxiliary outputs are zero-element tensors and must stay empty.
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs.
  Tensor input = tf.make(
      {4, 7},
      {2.876736640930176, 7.67944860458374, 5.701690196990967,
       9.299789428710938, 3.023690700531006, 5.315116882324219,
       7.185585021972656, 6.911304473876953, 7.61051082611084,
       1.4963287115097046, 0.7381612062454224, 8.588483810424805,
       6.583977699279785, 8.831110000610352, 0.8165055513381958,
       7.087201118469238, 5.572513580322266, 4.446897983551025,
       4.444573402404785, 6.254056930541992, 5.906398296356201,
       9.971039772033691, 3.5423521995544434, 7.452159881591797,
       9.93700122833252, 1.8560808897018433, 1.524025797843933,
       7.3222975730896});
  OptionalTensor weight(tf.make(
      {7},
      {8.287437438964844, 8.227645874023438, 6.65926456451416,
       9.436124801635742, 4.119281768798828, 8.593960762023926,
       2.3760855197906494}));
  OptionalTensor bias(tf.make(
      {7},
      {7.824275970458984, 6.84327507019043, 8.354326248168945,
       8.773970603942871, 3.89609694480896, 3.0753469467163086,
       3.1105971336364746}));
  // Left mutable: the fixture helper takes the running stats by non-const
  // reference.
  Tensor running_mean = tf.make(
      {7},
      {9.700226783752441, 0.1234668493270874, 7.527220249176025,
       8.993252754211426, 0.4736626148223877, 7.7135701179504395,
       5.12320613861084});
  Tensor running_var = tf.make(
      {7},
      {3.585531234741211, 6.615292549133301, 0.24084866046905518,
       5.175800323486328, 0.5886000394821167, 6.23909854888916,
       1.5029621124267578});
  const bool training = false;
  const double momentum = 0.1;
  const double eps = 0;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({4, 7});
  Tensor out1 = tf.zeros({0});
  Tensor out2 = tf.zeros({0});

  op_native_batch_norm_legit_out(
      input, weight, bias, running_mean, running_var, training, momentum,
      eps, out0, out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {4, 7},
      {-22.039867401123047, 31.014127731323242, -16.416650772094727,
       10.04538631439209, 17.5877628326416, -5.17673921585083,
       7.1078033447265625, -4.381907939910889, 30.793603897094727,
       -73.48003387451172, -25.46548080444336, 47.46636962890625,
       -0.8111140131950378, 10.29708194732666, -31.056814193725586,
       29.119586944580078, -18.16947364807129, -10.082839965820312,
       25.216796875, -1.9462348222732544, 4.628543376922607,
       9.00953483581543, 17.779958724975586, 7.335818767547607,
       12.688335418701172, 11.318607330322266, -18.22031593322754,
       7.372773170471191});
  Tensor out1_expected = tf.make({0}, {});
  Tensor out2_expected = tf.make({0}, {});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
979
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
  // Training-mode batch norm without running stats on a {3, 4} Float input
  // (channels along dim 1, i.e. the columns). out1 holds each column's mean
  // and out2 holds 1 / sqrt(biased column variance + eps); out0 is the
  // normalized input (no weight or bias).
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs: the squares 0^2 .. 11^2 laid out row-major.
  Tensor input =
      tf.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
  OptionalTensor weight{};
  OptionalTensor bias{};
  const bool training = true;
  const double momentum = 1e-3;
  const double eps = 1e-5;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({3, 4});
  Tensor out1 = tf.zeros({4});
  Tensor out2 = tf.zeros({4});

  op_native_batch_norm_legit_no_stats_out(
      input, weight, bias, training, momentum, eps, out0, out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {3, 4},
      {-0.98058063, -1.03422451, -1.06904495, -1.09332705,
       -0.39223224, -0.31822300, -0.26726127, -0.23017406,
       1.37281299, 1.35244739, 1.33630610, 1.32350123});
  Tensor out1_expected =
      tf.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
  Tensor out2_expected =
      tf.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
1019
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
  // Training-mode batch norm without running stats on a {2, 3, 4} Float
  // input. Statistics are reduced over dims 0 and 2 for each of the 3
  // channels: out1 is the channel mean, out2 is 1 / sqrt(biased variance +
  // eps), and out0 is the normalized input (no weight or bias).
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs: the squares 0^2 .. 23^2 laid out row-major.
  Tensor input = tf.make(
      {2, 3, 4},
      {0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121,
       144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
  OptionalTensor weight{};
  OptionalTensor bias{};
  const bool training = true;
  const double momentum = 1e-3;
  const double eps = 1e-5;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({2, 3, 4});
  Tensor out1 = tf.zeros({3});
  Tensor out2 = tf.zeros({3});

  op_native_batch_norm_legit_no_stats_out(
      input, weight, bias, training, momentum, eps, out0, out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {2, 3, 4},
      {-1.01045656, -0.99964952, -0.96722847, -0.91319335,
       -1.08850884, -1.02468753, -0.94668359, -0.85449719,
       -1.12558389, -1.03595889, -0.93578988, -0.82507670,
       0.54575467,  0.81593025,  1.10771990,  1.42112350,
       0.61339414,  0.84740579,  1.09560001,  1.35797679,
       0.64582670,  0.86198103,  1.08867943,  1.32592189});
  Tensor out1_expected = tf.make({3}, {93.5, 169.5, 277.5});
  Tensor out2_expected =
      tf.make({3}, {0.01080702, 0.00709126, 0.00527206});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
1052
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) {
  // Training-mode batch norm without running stats on a {2, 3, 2, 2} Float
  // input, this time with an affine weight and bias. Each channel sees the
  // same values as in the 3D case, so out1/out2 (channel mean and
  // 1 / sqrt(biased variance + eps)) match that test; out0 is the normalized
  // input scaled by weight[c] and shifted by bias[c].
  using Tensor = exec_aten::Tensor;
  using OptionalTensor = exec_aten::optional<exec_aten::Tensor>;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  // Operator inputs: the squares 0^2 .. 23^2 laid out row-major.
  Tensor input = tf.make(
      {2, 3, 2, 2},
      {0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121,
       144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
  OptionalTensor weight(tf.make({3}, {1.1, 0.7, 0.3}));
  OptionalTensor bias(tf.make({3}, {1.7, 2.2, 3.3}));
  const bool training = true;
  const double momentum = 1e-3;
  const double eps = 1e-5;

  // Output buffers the operator writes into.
  Tensor out0 = tf.zeros({2, 3, 2, 2});
  Tensor out1 = tf.zeros({3});
  Tensor out2 = tf.zeros({3});

  op_native_batch_norm_legit_no_stats_out(
      input, weight, bias, training, momentum, eps, out0, out1, out2);

  // Reference results.
  Tensor out0_expected = tf.make(
      {2, 3, 2, 2},
      {0.58849782, 0.60038555, 0.63604873, 0.69548732,
       1.43804383, 1.48271883, 1.53732157, 1.60185206,
       2.96232486, 2.98921227, 3.01926303, 3.05247688,
       2.30033016, 2.59752321, 2.91849184, 3.26323581,
       2.62937593, 2.79318404, 2.96691990, 3.15058374,
       3.49374819, 3.55859423, 3.62660384, 3.69777656});
  Tensor out1_expected = tf.make({3}, {93.5, 169.5, 277.5});
  Tensor out2_expected =
      tf.make({3}, {0.01080702, 0.00709126, 0.00527206});

  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
1087