1 //===-- Utility class to test integer sqrt ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "test/UnitTest/FPMatcher.h" 10 #include "test/UnitTest/Test.h" 11 12 #include "src/__support/CPP/bit.h" 13 #include "src/__support/FPUtil/BasicOperations.h" 14 #include "src/__support/FPUtil/sqrt.h" 15 #include "src/__support/fixed_point/fx_rep.h" 16 #include "src/__support/fixed_point/sqrt.h" 17 18 template <typename T> class ISqrtTest : public LIBC_NAMESPACE::testing::Test { 19 20 using OutType = 21 typename LIBC_NAMESPACE::fixed_point::internal::SqrtConfig<T>::OutType; 22 using FXRep = LIBC_NAMESPACE::fixed_point::FXRep<OutType>; 23 static constexpr OutType zero = FXRep::ZERO(); 24 static constexpr OutType one = static_cast<OutType>(1); 25 static constexpr OutType eps = FXRep::EPS(); 26 27 public: 28 typedef OutType (*SqrtFunc)(T); 29 testSpecificInput(T input,OutType result,double expected,double tolerance)30 void testSpecificInput(T input, OutType result, double expected, 31 double tolerance) { 32 double y_d = static_cast<double>(result); 33 double errors = LIBC_NAMESPACE::fputil::abs((y_d / expected) - 1.0); 34 if (errors > tolerance) { 35 // Print out the failure input and output. 36 EXPECT_EQ(input, T(0)); 37 EXPECT_EQ(result, zero); 38 } 39 ASSERT_TRUE(errors <= tolerance); 40 } 41 testSpecialNumbers(SqrtFunc func)42 void testSpecialNumbers(SqrtFunc func) { 43 EXPECT_EQ(zero, func(T(0))); 44 45 EXPECT_EQ(one, func(T(1))); 46 EXPECT_EQ(static_cast<OutType>(2.0), func(T(4))); 47 EXPECT_EQ(static_cast<OutType>(4.0), func(T(16))); 48 EXPECT_EQ(static_cast<OutType>(16.0), func(T(256))); 49 50 constexpr int COUNT = 255; 51 constexpr double ERR = 3.0 * static_cast<double>(eps); 52 double x_d = 0.0; 53 T x = 0; 54 for (int i = 0; i < COUNT; ++i) { 55 x_d += 1.0; 56 ++x; 57 OutType result = func(x); 58 double expected = LIBC_NAMESPACE::fputil::sqrt<double>(x_d); 59 testSpecificInput(x, result, expected, ERR); 60 } 61 } 62 }; 63 64 #define LIST_ISQRT_TESTS(Name, T, func) \ 65 using LlvmLibcISqrt##Name##Test = ISqrtTest<T>; \ 66 TEST_F(LlvmLibcISqrt##Name##Test, SpecialNumbers) { \ 67 testSpecialNumbers(&func); \ 68 } \ 69 static_assert(true, "Require semicolon.") 70