xref: /aosp_15_r20/external/llvm-libc/test/src/math/exhaustive/exhaustive_test.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- Exhaustive test template for math functions -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "src/__support/CPP/type_traits.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/macros/properties/types.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

#include <atomic>
#include <functional>
#include <ios>
#include <iostream>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>
23 
// To test exhaustively for inputs in the range [start, stop) in parallel:
// 1. Define a Checker class with:
//    - FloatType: the floating point type to be used.
//    - FPBits: fputil::FPBits<FloatType>.
//    - StorageType: the bit type for the corresponding floating point type.
//    - uint64_t check(start, stop, rounding_mode): a method to test in a given
//          range for a given rounding mode, which returns the number of
//          failures.
// 2. Use the LlvmLibcExhaustiveMathTest<Checker> class.
// 3. Call: test_full_range(rounding, start, stop)
//       or test_full_range_all_roundings(start, stop).
// * For single input single output math functions, use the convenient
//   template: LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
37 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
38 
// Function type of a single-input routine: takes one `InType` argument and
// returns `OutType`. `InType` defaults to `OutType`, so `UnaryOp<float>` is
// `float(float)`.
template <typename OutType, typename InType = OutType>
using UnaryOp = OutType(InType);
41 
42 template <typename OutType, typename InType, mpfr::Operation Op,
43           UnaryOp<OutType, InType> Func>
44 struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
45   using FloatType = InType;
46   using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
47   using StorageType = typename FPBits::StorageType;
48 
49   // Check in a range, return the number of failures.
checkUnaryOpChecker50   uint64_t check(StorageType start, StorageType stop,
51                  mpfr::RoundingMode rounding) {
52     mpfr::ForceRoundingMode r(rounding);
53     if (!r.success)
54       return (stop > start);
55     StorageType bits = start;
56     uint64_t failed = 0;
57     do {
58       FPBits xbits(bits);
59       FloatType x = xbits.get_val();
60       bool correct =
61           TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding);
62       failed += (!correct);
63       // Uncomment to print out failed values.
64       if (!correct) {
65         EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
66       }
67     } while (bits++ < stop);
68     return failed;
69   }
70 };
71 
// Function type of a two-input routine: takes two `InType` arguments and
// returns `OutType`. `InType` defaults to `OutType`, so `BinaryOp<float>` is
// `float(float, float)`.
template <typename OutType, typename InType = OutType>
using BinaryOp = OutType(InType, InType);
74 
75 template <typename OutType, typename InType, mpfr::Operation Op,
76           BinaryOp<OutType, InType> Func>
77 struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
78   using FloatType = InType;
79   using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
80   using StorageType = typename FPBits::StorageType;
81 
82   // Check in a range, return the number of failures.
checkBinaryOpChecker83   uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
84                  StorageType y_stop, mpfr::RoundingMode rounding) {
85     mpfr::ForceRoundingMode r(rounding);
86     if (!r.success)
87       return x_stop > x_start || y_stop > y_start;
88     StorageType xbits = x_start;
89     uint64_t failed = 0;
90     do {
91       FloatType x = FPBits(xbits).get_val();
92       StorageType ybits = y_start;
93       do {
94         FloatType y = FPBits(ybits).get_val();
95         mpfr::BinaryInput<FloatType> input{x, y};
96         bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y),
97                                                          0.5, rounding);
98         failed += (!correct);
99         // Uncomment to print out failed values.
100         if (!correct) {
101           EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
102         }
103       } while (ybits++ < y_stop);
104     } while (xbits++ < x_stop);
105     return failed;
106   }
107 };
108 
109 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
110 //   StorageType and check method.
111 template <typename Checker, size_t Increment = 1 << 20>
112 struct LlvmLibcExhaustiveMathTest
113     : public virtual LIBC_NAMESPACE::testing::Test,
114       public Checker {
115   using FloatType = typename Checker::FloatType;
116   using FPBits = typename Checker::FPBits;
117   using StorageType = typename Checker::StorageType;
118 
explain_failed_rangeLlvmLibcExhaustiveMathTest119   void explain_failed_range(std::stringstream &msg, StorageType x_begin,
120                             StorageType x_end) {
121 #ifdef LIBC_TYPES_HAS_FLOAT16
122     using T = LIBC_NAMESPACE::cpp::conditional_t<
123         LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>;
124 #else
125     using T = FloatType;
126 #endif
127 
128     msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x"
129         << x_end << "), [" << std::hexfloat
130         << static_cast<T>(FPBits(x_begin).get_val()) << ", "
131         << static_cast<T>(FPBits(x_end).get_val()) << ")";
132   }
133 
explain_failed_rangeLlvmLibcExhaustiveMathTest134   void explain_failed_range(std::stringstream &msg, StorageType x_begin,
135                             StorageType x_end, StorageType y_begin,
136                             StorageType y_end) {
137     msg << "x ";
138     explain_failed_range(msg, x_begin, x_end);
139     msg << ", y ";
140     explain_failed_range(msg, y_begin, y_end);
141   }
142 
143   // Break [start, stop) into `nthreads` subintervals and apply *check to each
144   // subinterval in parallel.
145   template <typename... T>
test_full_rangeLlvmLibcExhaustiveMathTest146   void test_full_range(mpfr::RoundingMode rounding, StorageType start,
147                        StorageType stop, T... extra_range_bounds) {
148     int n_threads = std::thread::hardware_concurrency();
149     std::vector<std::thread> thread_list;
150     std::mutex mx_cur_val;
151     int current_percent = -1;
152     StorageType current_value = start;
153     std::atomic<uint64_t> failed(0);
154 
155     for (int i = 0; i < n_threads; ++i) {
156       thread_list.emplace_back([&, this]() {
157         while (true) {
158           StorageType range_begin, range_end;
159           int new_percent = -1;
160           {
161             std::lock_guard<std::mutex> lock(mx_cur_val);
162             if (current_value == stop)
163               return;
164 
165             range_begin = current_value;
166             if (stop >= Increment && stop - Increment >= current_value) {
167               range_end = current_value + Increment;
168             } else {
169               range_end = stop;
170             }
171             current_value = range_end;
172             int pc = 100.0 * (range_end - start) / (stop - start);
173             if (current_percent != pc) {
174               new_percent = pc;
175               current_percent = pc;
176             }
177           }
178           if (new_percent >= 0) {
179             std::stringstream msg;
180             msg << new_percent << "% is in process     \r";
181             std::cout << msg.str() << std::flush;
182           }
183 
184           uint64_t failed_in_range = Checker::check(
185               range_begin, range_end, extra_range_bounds..., rounding);
186           if (failed_in_range > 0) {
187             std::stringstream msg;
188             msg << "Test failed for " << std::dec << failed_in_range
189                 << " inputs in range: ";
190             explain_failed_range(msg, range_begin, range_end,
191                                  extra_range_bounds...);
192             msg << "\n";
193             std::cerr << msg.str() << std::flush;
194 
195             failed.fetch_add(failed_in_range);
196           }
197         }
198       });
199     }
200 
201     for (auto &thread : thread_list) {
202       if (thread.joinable()) {
203         thread.join();
204       }
205     }
206 
207     std::cout << std::endl;
208     std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
209     ASSERT_EQ(failed.load(), uint64_t(0));
210   }
211 
test_full_range_all_roundingsLlvmLibcExhaustiveMathTest212   void test_full_range_all_roundings(StorageType start, StorageType stop) {
213     std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
214               << ", 0x" << stop << ") --" << std::dec << std::endl;
215     test_full_range(mpfr::RoundingMode::Nearest, start, stop);
216 
217     std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
218               << ", 0x" << stop << ") --" << std::dec << std::endl;
219     test_full_range(mpfr::RoundingMode::Upward, start, stop);
220 
221     std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
222               << ", 0x" << stop << ") --" << std::dec << std::endl;
223     test_full_range(mpfr::RoundingMode::Downward, start, stop);
224 
225     std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
226               << start << ", 0x" << stop << ") --" << std::dec << std::endl;
227     test_full_range(mpfr::RoundingMode::TowardZero, start, stop);
228   };
229 
test_full_range_all_roundingsLlvmLibcExhaustiveMathTest230   void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
231                                      StorageType y_start, StorageType y_stop) {
232     std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex
233               << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
234               << ", 0x" << y_stop << ") --" << std::dec << std::endl;
235     test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
236                     y_stop);
237 
238     std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
239               << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
240               << ", 0x" << y_stop << ") --" << std::dec << std::endl;
241     test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
242                     y_stop);
243 
244     std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
245               << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
246               << ", 0x" << y_stop << ") --" << std::dec << std::endl;
247     test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
248                     y_stop);
249 
250     std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
251               << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
252               << ", 0x" << y_stop << ") --" << std::dec << std::endl;
253     test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
254                     y_stop);
255   };
256 };
257 
258 template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
259 using LlvmLibcUnaryOpExhaustiveMathTest =
260     LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
261 
262 template <typename OutType, typename InType, mpfr::Operation Op,
263           UnaryOp<OutType, InType> Func>
264 using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
265     LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
266 
267 template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
268 using LlvmLibcBinaryOpExhaustiveMathTest =
269     LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
270                                1 << 2>;
271