xref: /aosp_15_r20/external/llvm-libc/test/src/math/FmaTest.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
//===-- Utility class to test different flavors of fma --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
10 #define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
11 
12 #include "src/stdlib/rand.h"
13 #include "src/stdlib/srand.h"
14 #include "test/UnitTest/FEnvSafeTest.h"
15 #include "test/UnitTest/FPMatcher.h"
16 #include "test/UnitTest/Test.h"
17 #include "utils/MPFRWrapper/MPFRUtils.h"
18 
19 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
20 
21 template <typename OutType, typename InType = OutType>
22 class FmaTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
23 
24   struct OutConstants {
25     DECLARE_SPECIAL_CONSTANTS(OutType)
26   };
27 
28   struct InConstants {
29     DECLARE_SPECIAL_CONSTANTS(InType)
30   };
31 
32   using OutFPBits = typename OutConstants::FPBits;
33   using OutStorageType = typename OutConstants::StorageType;
34   using InFPBits = typename InConstants::FPBits;
35   using InStorageType = typename InConstants::StorageType;
36 
37   static constexpr OutStorageType OUT_MIN_NORMAL_U =
38       OutFPBits::min_normal().uintval();
39   static constexpr InStorageType IN_MAX_NORMAL_U =
40       InFPBits::max_normal().uintval();
41   static constexpr InStorageType IN_MIN_NORMAL_U =
42       InFPBits::min_normal().uintval();
43   static constexpr InStorageType IN_MAX_SUBNORMAL_U =
44       InFPBits::max_subnormal().uintval();
45   static constexpr InStorageType IN_MIN_SUBNORMAL_U =
46       InFPBits::min_subnormal().uintval();
47 
get_random_bit_pattern()48   InStorageType get_random_bit_pattern() {
49     InStorageType bits{0};
50     for (InStorageType i = 0; i < sizeof(InStorageType) / 2; ++i) {
51       bits = (bits << 2) + static_cast<uint16_t>(LIBC_NAMESPACE::rand());
52     }
53     return bits;
54   }
55 
56 public:
57   using FmaFunc = OutType (*)(InType, InType, InType);
58 
test_subnormal_range(FmaFunc func)59   void test_subnormal_range(FmaFunc func) {
60     constexpr InStorageType COUNT = 100'001;
61     constexpr InStorageType STEP =
62         (IN_MAX_SUBNORMAL_U - IN_MIN_SUBNORMAL_U) / COUNT;
63     LIBC_NAMESPACE::srand(1);
64     for (InStorageType v = IN_MIN_SUBNORMAL_U, w = IN_MAX_SUBNORMAL_U;
65          v <= IN_MAX_SUBNORMAL_U && w >= IN_MIN_SUBNORMAL_U;
66          v += STEP, w -= STEP) {
67       InType x = InFPBits(get_random_bit_pattern()).get_val();
68       InType y = InFPBits(v).get_val();
69       InType z = InFPBits(w).get_val();
70       mpfr::TernaryInput<InType> input{x, y, z};
71       ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, func(x, y, z),
72                                      0.5);
73     }
74   }
75 
76   void test_normal_range(FmaFunc func) {
77     constexpr InStorageType COUNT = 100'001;
78     constexpr InStorageType STEP = (IN_MAX_NORMAL_U - IN_MIN_NORMAL_U) / COUNT;
79     LIBC_NAMESPACE::srand(1);
80     for (InStorageType v = IN_MIN_NORMAL_U, w = IN_MAX_NORMAL_U;
81          v <= IN_MAX_NORMAL_U && w >= IN_MIN_NORMAL_U; v += STEP, w -= STEP) {
82       InType x = InFPBits(v).get_val();
83       InType y = InFPBits(w).get_val();
84       InType z = InFPBits(get_random_bit_pattern()).get_val();
85       mpfr::TernaryInput<InType> input{x, y, z};
86       ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, func(x, y, z),
87                                      0.5);
88     }
89   }
90 };
91 
// Registers the SubnormalRange and NormalRange tests for a non-narrowing fma
// variant whose operands and result share the type T. This is the narrowing
// registration with InType == OutType (FmaTestTemplate<T> is
// FmaTestTemplate<T, T>).
#define LIST_FMA_TESTS(T, func) LIST_NARROWING_FMA_TESTS(T, T, func)
96 
// Registers the SubnormalRange and NormalRange tests for an fma variant that
// takes three IN_T operands and returns OUT_T, using the fixture name
// LlvmLibcFmaTest.
#define LIST_NARROWING_FMA_TESTS(OUT_T, IN_T, FN)                              \
  using LlvmLibcFmaTest = FmaTestTemplate<OUT_T, IN_T>;                        \
  TEST_F(LlvmLibcFmaTest, SubnormalRange) {                                    \
    test_subnormal_range(&FN);                                                 \
  }                                                                            \
  TEST_F(LlvmLibcFmaTest, NormalRange) {                                       \
    test_normal_range(&FN);                                                    \
  }
101 
102 #endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
103