1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10
11 #include <executorch/kernels/optimized/cpu/moments_utils.h>
12
13 #include <vector>
14
// Invokes the function template `_` once per listed element type, in the order
// double, float, short. NOTE(review): despite "FLOAT" in the name, the list
// also contains the integral type `short`.
#define TEST_FORALL_FLOAT_CTYPES(_) \
  _<double>();                      \
  _<float>();                       \
  _<short>();
19
namespace {

// Returns true when `val` and `ref` agree to within `tol`.
//
// The comparison is carried out in T: both `ref` and `tol` are converted to T
// before comparing.
//   - For floating-point T this is an ordinary absolute-tolerance check. A NaN
//     on either side makes the result false.
//   - For integral T the conversions truncate. The fractional part of `ref` is
//     dropped and a tolerance below 1 becomes 0, so the check reduces to exact
//     equality against the truncated reference.
// Converting a `ref` that is out of range for T is undefined behavior, as with
// any float-to-integer cast.
//
// The absolute difference is formed by subtracting the smaller operand from
// the larger one instead of calling std::abs. This keeps the helper correct
// for unsigned T: `val - ref` would wrap, and std::abs has no unsigned
// overload. It also means no <cmath>/<cstdlib> include is required.
template <class T>
bool is_close(T val, float ref, float tol = 1e-5) {
  const T ref_t = static_cast<T>(ref);
  const T diff = (val > ref_t) ? static_cast<T>(val - ref_t)
                               : static_cast<T>(ref_t - val);
  return diff <= static_cast<T>(tol);
}

} // namespace
30
31 template <class CTYPE>
test_calc_moments()32 void test_calc_moments() {
33 using torch::executor::native::RowwiseMoments;
34
35 std::vector<CTYPE> in({2, 3, 4, 5, 9, 10, 12, 13});
36
37 float mean;
38 float variance;
39 const CTYPE* in_data = in.data();
40 std::tie(mean, variance) = RowwiseMoments(in_data, 8);
41
42 EXPECT_TRUE(is_close<CTYPE>(mean, 7.25f));
43 EXPECT_TRUE(is_close<CTYPE>(variance, 15.9375f));
44 }
45
// Runs test_calc_moments once for each element type listed in
// TEST_FORALL_FLOAT_CTYPES (double, float and short).
TEST(MomentsUtilTest, CalculateMoments) {
  TEST_FORALL_FLOAT_CTYPES(test_calc_moments)
}
49