xref: /aosp_15_r20/external/eigen/test/numext.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1*bf2c3715SXin Li // This file is part of Eigen, a lightweight C++ template library
2*bf2c3715SXin Li // for linear algebra.
3*bf2c3715SXin Li //
4*bf2c3715SXin Li // Copyright (C) 2017 Gael Guennebaud <[email protected]>
5*bf2c3715SXin Li //
6*bf2c3715SXin Li // This Source Code Form is subject to the terms of the Mozilla
7*bf2c3715SXin Li // Public License v. 2.0. If a copy of the MPL was not distributed
8*bf2c3715SXin Li // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9*bf2c3715SXin Li 
10*bf2c3715SXin Li #include "main.h"
11*bf2c3715SXin Li 
12*bf2c3715SXin Li template<typename T, typename U>
check_if_equal_or_nans(const T & actual,const U & expected)13*bf2c3715SXin Li bool check_if_equal_or_nans(const T& actual, const U& expected) {
14*bf2c3715SXin Li   return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
15*bf2c3715SXin Li }
16*bf2c3715SXin Li 
17*bf2c3715SXin Li template<typename T, typename U>
check_if_equal_or_nans(const std::complex<T> & actual,const std::complex<U> & expected)18*bf2c3715SXin Li bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
19*bf2c3715SXin Li   return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
20*bf2c3715SXin Li          && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
21*bf2c3715SXin Li }
22*bf2c3715SXin Li 
// Reporting wrapper around check_if_equal_or_nans: on a mismatch it writes both
// values to std::cerr so the failing pair shows up in the test log.
// Returns the result of the comparison.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr
        << "\n    actual   = " << actual
        << "\n    expected = " << expected << "\n\n";
  }
  return matches;
}

// Verifies that `a` equals `b`, treating NaN (or NaN components of complex values) as equal.
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
38*bf2c3715SXin Li 
39*bf2c3715SXin Li template<typename T>
check_abs()40*bf2c3715SXin Li void check_abs() {
41*bf2c3715SXin Li   typedef typename NumTraits<T>::Real Real;
42*bf2c3715SXin Li   Real zero(0);
43*bf2c3715SXin Li 
44*bf2c3715SXin Li   if(NumTraits<T>::IsSigned)
45*bf2c3715SXin Li     VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
46*bf2c3715SXin Li   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
47*bf2c3715SXin Li   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
48*bf2c3715SXin Li 
49*bf2c3715SXin Li   for(int k=0; k<100; ++k)
50*bf2c3715SXin Li   {
51*bf2c3715SXin Li     T x = internal::random<T>();
52*bf2c3715SXin Li     if(!internal::is_same<T,bool>::value)
53*bf2c3715SXin Li       x = x/Real(2);
54*bf2c3715SXin Li     if(NumTraits<T>::IsSigned)
55*bf2c3715SXin Li     {
56*bf2c3715SXin Li       VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
57*bf2c3715SXin Li       VERIFY( numext::abs(-x) >= zero );
58*bf2c3715SXin Li     }
59*bf2c3715SXin Li     VERIFY( numext::abs(x) >= zero );
60*bf2c3715SXin Li     VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
61*bf2c3715SXin Li   }
62*bf2c3715SXin Li }
63*bf2c3715SXin Li 
64*bf2c3715SXin Li template<typename T>
check_arg()65*bf2c3715SXin Li void check_arg() {
66*bf2c3715SXin Li   typedef typename NumTraits<T>::Real Real;
67*bf2c3715SXin Li   VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
68*bf2c3715SXin Li   VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
69*bf2c3715SXin Li 
70*bf2c3715SXin Li   for(int k=0; k<100; ++k)
71*bf2c3715SXin Li   {
72*bf2c3715SXin Li     T x = internal::random<T>();
73*bf2c3715SXin Li     Real y = numext::arg(x);
74*bf2c3715SXin Li     VERIFY_IS_APPROX( y, std::arg(x) );
75*bf2c3715SXin Li   }
76*bf2c3715SXin Li }
77*bf2c3715SXin Li 
78*bf2c3715SXin Li template<typename T>
79*bf2c3715SXin Li struct check_sqrt_impl {
runcheck_sqrt_impl80*bf2c3715SXin Li   static void run() {
81*bf2c3715SXin Li     for (int i=0; i<1000; ++i) {
82*bf2c3715SXin Li       const T x = numext::abs(internal::random<T>());
83*bf2c3715SXin Li       const T sqrtx = numext::sqrt(x);
84*bf2c3715SXin Li       VERIFY_IS_APPROX(sqrtx*sqrtx, x);
85*bf2c3715SXin Li     }
86*bf2c3715SXin Li 
87*bf2c3715SXin Li     // Corner cases.
88*bf2c3715SXin Li     const T zero = T(0);
89*bf2c3715SXin Li     const T one = T(1);
90*bf2c3715SXin Li     const T inf = std::numeric_limits<T>::infinity();
91*bf2c3715SXin Li     const T nan = std::numeric_limits<T>::quiet_NaN();
92*bf2c3715SXin Li     VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
93*bf2c3715SXin Li     VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
94*bf2c3715SXin Li     VERIFY((numext::isnan)(numext::sqrt(nan)));
95*bf2c3715SXin Li     VERIFY((numext::isnan)(numext::sqrt(-one)));
96*bf2c3715SXin Li   }
97*bf2c3715SXin Li };
98*bf2c3715SXin Li 
99*bf2c3715SXin Li template<typename T>
100*bf2c3715SXin Li struct check_sqrt_impl<std::complex<T>  > {
runcheck_sqrt_impl101*bf2c3715SXin Li   static void run() {
102*bf2c3715SXin Li     typedef typename std::complex<T> ComplexT;
103*bf2c3715SXin Li 
104*bf2c3715SXin Li     for (int i=0; i<1000; ++i) {
105*bf2c3715SXin Li       const ComplexT x = internal::random<ComplexT>();
106*bf2c3715SXin Li       const ComplexT sqrtx = numext::sqrt(x);
107*bf2c3715SXin Li       VERIFY_IS_APPROX(sqrtx*sqrtx, x);
108*bf2c3715SXin Li     }
109*bf2c3715SXin Li 
110*bf2c3715SXin Li     // Corner cases.
111*bf2c3715SXin Li     const T zero = T(0);
112*bf2c3715SXin Li     const T one = T(1);
113*bf2c3715SXin Li     const T inf = std::numeric_limits<T>::infinity();
114*bf2c3715SXin Li     const T nan = std::numeric_limits<T>::quiet_NaN();
115*bf2c3715SXin Li 
116*bf2c3715SXin Li     // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
117*bf2c3715SXin Li     const int kNumCorners = 20;
118*bf2c3715SXin Li     const ComplexT corners[kNumCorners][2] = {
119*bf2c3715SXin Li       {ComplexT(zero, zero), ComplexT(zero, zero)},
120*bf2c3715SXin Li       {ComplexT(-zero, zero), ComplexT(zero, zero)},
121*bf2c3715SXin Li       {ComplexT(zero, -zero), ComplexT(zero, zero)},
122*bf2c3715SXin Li       {ComplexT(-zero, -zero), ComplexT(zero, zero)},
123*bf2c3715SXin Li       {ComplexT(one, inf), ComplexT(inf, inf)},
124*bf2c3715SXin Li       {ComplexT(nan, inf), ComplexT(inf, inf)},
125*bf2c3715SXin Li       {ComplexT(one, -inf), ComplexT(inf, -inf)},
126*bf2c3715SXin Li       {ComplexT(nan, -inf), ComplexT(inf, -inf)},
127*bf2c3715SXin Li       {ComplexT(-inf, one), ComplexT(zero, inf)},
128*bf2c3715SXin Li       {ComplexT(inf, one), ComplexT(inf, zero)},
129*bf2c3715SXin Li       {ComplexT(-inf, -one), ComplexT(zero, -inf)},
130*bf2c3715SXin Li       {ComplexT(inf, -one), ComplexT(inf, -zero)},
131*bf2c3715SXin Li       {ComplexT(-inf, nan), ComplexT(nan, inf)},
132*bf2c3715SXin Li       {ComplexT(inf, nan), ComplexT(inf, nan)},
133*bf2c3715SXin Li       {ComplexT(zero, nan), ComplexT(nan, nan)},
134*bf2c3715SXin Li       {ComplexT(one, nan), ComplexT(nan, nan)},
135*bf2c3715SXin Li       {ComplexT(nan, zero), ComplexT(nan, nan)},
136*bf2c3715SXin Li       {ComplexT(nan, one), ComplexT(nan, nan)},
137*bf2c3715SXin Li       {ComplexT(nan, -one), ComplexT(nan, nan)},
138*bf2c3715SXin Li       {ComplexT(nan, nan), ComplexT(nan, nan)},
139*bf2c3715SXin Li     };
140*bf2c3715SXin Li 
141*bf2c3715SXin Li     for (int i=0; i<kNumCorners; ++i) {
142*bf2c3715SXin Li       const ComplexT& x = corners[i][0];
143*bf2c3715SXin Li       const ComplexT sqrtx = corners[i][1];
144*bf2c3715SXin Li       VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
145*bf2c3715SXin Li     }
146*bf2c3715SXin Li   }
147*bf2c3715SXin Li };
148*bf2c3715SXin Li 
149*bf2c3715SXin Li template<typename T>
check_sqrt()150*bf2c3715SXin Li void check_sqrt() {
151*bf2c3715SXin Li   check_sqrt_impl<T>::run();
152*bf2c3715SXin Li }
153*bf2c3715SXin Li 
154*bf2c3715SXin Li template<typename T>
155*bf2c3715SXin Li struct check_rsqrt_impl {
runcheck_rsqrt_impl156*bf2c3715SXin Li   static void run() {
157*bf2c3715SXin Li     const T zero = T(0);
158*bf2c3715SXin Li     const T one = T(1);
159*bf2c3715SXin Li     const T inf = std::numeric_limits<T>::infinity();
160*bf2c3715SXin Li     const T nan = std::numeric_limits<T>::quiet_NaN();
161*bf2c3715SXin Li 
162*bf2c3715SXin Li     for (int i=0; i<1000; ++i) {
163*bf2c3715SXin Li       const T x = numext::abs(internal::random<T>());
164*bf2c3715SXin Li       const T rsqrtx = numext::rsqrt(x);
165*bf2c3715SXin Li       const T invx = one / x;
166*bf2c3715SXin Li       VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
167*bf2c3715SXin Li     }
168*bf2c3715SXin Li 
169*bf2c3715SXin Li     // Corner cases.
170*bf2c3715SXin Li     VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
171*bf2c3715SXin Li     VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
172*bf2c3715SXin Li     VERIFY((numext::isnan)(numext::rsqrt(nan)));
173*bf2c3715SXin Li     VERIFY((numext::isnan)(numext::rsqrt(-one)));
174*bf2c3715SXin Li   }
175*bf2c3715SXin Li };
176*bf2c3715SXin Li 
177*bf2c3715SXin Li template<typename T>
178*bf2c3715SXin Li struct check_rsqrt_impl<std::complex<T> > {
runcheck_rsqrt_impl179*bf2c3715SXin Li   static void run() {
180*bf2c3715SXin Li     typedef typename std::complex<T> ComplexT;
181*bf2c3715SXin Li     const T zero = T(0);
182*bf2c3715SXin Li     const T one = T(1);
183*bf2c3715SXin Li     const T inf = std::numeric_limits<T>::infinity();
184*bf2c3715SXin Li     const T nan = std::numeric_limits<T>::quiet_NaN();
185*bf2c3715SXin Li 
186*bf2c3715SXin Li     for (int i=0; i<1000; ++i) {
187*bf2c3715SXin Li       const ComplexT x = internal::random<ComplexT>();
188*bf2c3715SXin Li       const ComplexT invx = ComplexT(one, zero) / x;
189*bf2c3715SXin Li       const ComplexT rsqrtx = numext::rsqrt(x);
190*bf2c3715SXin Li       VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
191*bf2c3715SXin Li     }
192*bf2c3715SXin Li 
193*bf2c3715SXin Li     // GCC and MSVC differ in their treatment of 1/(0 + 0i)
194*bf2c3715SXin Li     //   GCC/clang = (inf, nan)
195*bf2c3715SXin Li     //   MSVC = (nan, nan)
196*bf2c3715SXin Li     // and 1 / (x + inf i)
197*bf2c3715SXin Li     //   GCC/clang = (0, 0)
198*bf2c3715SXin Li     //   MSVC = (nan, nan)
199*bf2c3715SXin Li     #if (EIGEN_COMP_GNUC)
200*bf2c3715SXin Li     {
201*bf2c3715SXin Li       const int kNumCorners = 20;
202*bf2c3715SXin Li       const ComplexT corners[kNumCorners][2] = {
203*bf2c3715SXin Li         // Only consistent across GCC, clang
204*bf2c3715SXin Li         {ComplexT(zero, zero), ComplexT(zero, zero)},
205*bf2c3715SXin Li         {ComplexT(-zero, zero), ComplexT(zero, zero)},
206*bf2c3715SXin Li         {ComplexT(zero, -zero), ComplexT(zero, zero)},
207*bf2c3715SXin Li         {ComplexT(-zero, -zero), ComplexT(zero, zero)},
208*bf2c3715SXin Li         {ComplexT(one, inf), ComplexT(inf, inf)},
209*bf2c3715SXin Li         {ComplexT(nan, inf), ComplexT(inf, inf)},
210*bf2c3715SXin Li         {ComplexT(one, -inf), ComplexT(inf, -inf)},
211*bf2c3715SXin Li         {ComplexT(nan, -inf), ComplexT(inf, -inf)},
212*bf2c3715SXin Li         // Consistent across GCC, clang, MSVC
213*bf2c3715SXin Li         {ComplexT(-inf, one), ComplexT(zero, inf)},
214*bf2c3715SXin Li         {ComplexT(inf, one), ComplexT(inf, zero)},
215*bf2c3715SXin Li         {ComplexT(-inf, -one), ComplexT(zero, -inf)},
216*bf2c3715SXin Li         {ComplexT(inf, -one), ComplexT(inf, -zero)},
217*bf2c3715SXin Li         {ComplexT(-inf, nan), ComplexT(nan, inf)},
218*bf2c3715SXin Li         {ComplexT(inf, nan), ComplexT(inf, nan)},
219*bf2c3715SXin Li         {ComplexT(zero, nan), ComplexT(nan, nan)},
220*bf2c3715SXin Li         {ComplexT(one, nan), ComplexT(nan, nan)},
221*bf2c3715SXin Li         {ComplexT(nan, zero), ComplexT(nan, nan)},
222*bf2c3715SXin Li         {ComplexT(nan, one), ComplexT(nan, nan)},
223*bf2c3715SXin Li         {ComplexT(nan, -one), ComplexT(nan, nan)},
224*bf2c3715SXin Li         {ComplexT(nan, nan), ComplexT(nan, nan)},
225*bf2c3715SXin Li       };
226*bf2c3715SXin Li 
227*bf2c3715SXin Li       for (int i=0; i<kNumCorners; ++i) {
228*bf2c3715SXin Li         const ComplexT& x = corners[i][0];
229*bf2c3715SXin Li         const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
230*bf2c3715SXin Li         VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
231*bf2c3715SXin Li       }
232*bf2c3715SXin Li     }
233*bf2c3715SXin Li     #endif
234*bf2c3715SXin Li   }
235*bf2c3715SXin Li };
236*bf2c3715SXin Li 
237*bf2c3715SXin Li template<typename T>
check_rsqrt()238*bf2c3715SXin Li void check_rsqrt() {
239*bf2c3715SXin Li   check_rsqrt_impl<T>::run();
240*bf2c3715SXin Li }
241*bf2c3715SXin Li 
// Test entry point: repeats the full set of numext subtests g_repeat times.
EIGEN_DECLARE_TEST(numext) {
  for (int repeat = 0; repeat < g_repeat; ++repeat)
  {
    // numext::abs over bool, integer, half/bfloat16, floating-point and complex scalars.
    CALL_SUBTEST( check_abs<bool>() );
    CALL_SUBTEST( check_abs<signed char>() );
    CALL_SUBTEST( check_abs<unsigned char>() );
    CALL_SUBTEST( check_abs<short>() );
    CALL_SUBTEST( check_abs<unsigned short>() );
    CALL_SUBTEST( check_abs<int>() );
    CALL_SUBTEST( check_abs<unsigned int>() );
    CALL_SUBTEST( check_abs<long>() );
    CALL_SUBTEST( check_abs<unsigned long>() );
    CALL_SUBTEST( check_abs<half>() );
    CALL_SUBTEST( check_abs<bfloat16>() );
    CALL_SUBTEST( check_abs<float>() );
    CALL_SUBTEST( check_abs<double>() );
    CALL_SUBTEST( check_abs<long double>() );
    CALL_SUBTEST( check_abs<std::complex<float> >() );
    CALL_SUBTEST( check_abs<std::complex<double> >() );

    // numext::arg (complex types only).
    CALL_SUBTEST( check_arg<std::complex<float> >() );
    CALL_SUBTEST( check_arg<std::complex<double> >() );

    // numext::sqrt on real and complex types.
    CALL_SUBTEST( check_sqrt<float>() );
    CALL_SUBTEST( check_sqrt<double>() );
    CALL_SUBTEST( check_sqrt<std::complex<float> >() );
    CALL_SUBTEST( check_sqrt<std::complex<double> >() );

    // numext::rsqrt on real and complex types.
    CALL_SUBTEST( check_rsqrt<float>() );
    CALL_SUBTEST( check_rsqrt<double>() );
    CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
    CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  }
}
276