xref: /aosp_15_r20/external/executorch/kernels/optimized/test/moments_utils_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <gtest/gtest.h>
10 
11 #include <executorch/kernels/optimized/cpu/moments_utils.h>
12 
13 #include <vector>
14 
// Expands to three statements that call the function template `_` with no
// arguments, instantiated in turn for double, float and short. The expansion
// already ends in a semicolon, so call sites do not need to add one.
// NOTE(review): `short` is an integral type even though the macro name says
// FLOAT -- confirm whether it is meant as a stand-in for a 16-bit float type.
#define TEST_FORALL_FLOAT_CTYPES(_) \
  _<double>();                      \
  _<float>();                       \
  _<short>();
19 
namespace {

// Returns true when `val` is within `tol` of `ref`.
//
// `ref` and `tol` are both converted to T before the comparison, and the
// absolute difference is stored back into a T as well. For an integral T this
// means the reference and the tolerance are truncated toward zero.
template <class T>
bool is_close(T val, float ref, float tol = 1e-5) {
  const T expected = static_cast<T>(ref);
  const T allowed = static_cast<T>(tol);
  const T distance = std::abs(val - expected);
  return distance <= allowed;
}

} // namespace
30 
31 template <class CTYPE>
test_calc_moments()32 void test_calc_moments() {
33   using torch::executor::native::RowwiseMoments;
34 
35   std::vector<CTYPE> in({2, 3, 4, 5, 9, 10, 12, 13});
36 
37   float mean;
38   float variance;
39   const CTYPE* in_data = in.data();
40   std::tie(mean, variance) = RowwiseMoments(in_data, 8);
41 
42   EXPECT_TRUE(is_close<CTYPE>(mean, 7.25f));
43   EXPECT_TRUE(is_close<CTYPE>(variance, 15.9375f));
44 }
45 
// Runs test_calc_moments once for every type listed in
// TEST_FORALL_FLOAT_CTYPES. The macro expansion supplies its own trailing
// semicolon, so none is written after the invocation.
TEST(MomentsUtilTest, CalculateMoments) {
  TEST_FORALL_FLOAT_CTYPES(test_calc_moments)
}
49