xref: /aosp_15_r20/external/gemmlowp/test/test_fixedpoint.cc (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // test_fixedpoint.cc: unit tests covering the fixedpoint/ directory.
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
18*5f39d1b3SJooyung Han 
19*5f39d1b3SJooyung Han #include <algorithm>
20*5f39d1b3SJooyung Han #include <cinttypes>
21*5f39d1b3SJooyung Han #include <cmath>
22*5f39d1b3SJooyung Han #include <cstdio>
23*5f39d1b3SJooyung Han #include <random>
24*5f39d1b3SJooyung Han #include <vector>
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han #include "../fixedpoint/fixedpoint.h"
27*5f39d1b3SJooyung Han #include "test.h"
28*5f39d1b3SJooyung Han 
29*5f39d1b3SJooyung Han namespace gemmlowp {
30*5f39d1b3SJooyung Han 
31*5f39d1b3SJooyung Han namespace {
32*5f39d1b3SJooyung Han 
33*5f39d1b3SJooyung Han template <typename T>
Load(const typename FixedPointRawTypeTraits<T>::ScalarRawType * src)34*5f39d1b3SJooyung Han T Load(const typename FixedPointRawTypeTraits<T>::ScalarRawType* src) {
35*5f39d1b3SJooyung Han   return *src;
36*5f39d1b3SJooyung Han }
37*5f39d1b3SJooyung Han template <typename T>
Store(typename FixedPointRawTypeTraits<T>::ScalarRawType * dst,T v)38*5f39d1b3SJooyung Han void Store(typename FixedPointRawTypeTraits<T>::ScalarRawType* dst, T v) {
39*5f39d1b3SJooyung Han   *dst = v;
40*5f39d1b3SJooyung Han }
#ifdef GEMMLOWP_NEON
// NEON specializations: int32x4_t (4 x int32 lanes) and int16x8_t
// (8 x int16 lanes) are moved with the vld1q_* / vst1q_* intrinsics.

// --- int32x4_t ---
template <>
int32x4_t Load<int32x4_t>(const std::int32_t* src) {
  return vld1q_s32(src);
}
template <>
void Store<int32x4_t>(std::int32_t* dst, int32x4_t v) {
  vst1q_s32(dst, v);
}

// --- int16x8_t ---
template <>
int16x8_t Load<int16x8_t>(const std::int16_t* src) {
  return vld1q_s16(src);
}
template <>
void Store<int16x8_t>(std::int16_t* dst, int16x8_t v) {
  vst1q_s16(dst, v);
}
#endif  // GEMMLOWP_NEON
#ifdef GEMMLOWP_SSE4
// SSE4 specializations. int32 lanes travel as a bare __m128i. int16 lanes
// travel as the int16x8_m128i wrapper, whose `.v` member holds the underlying
// __m128i. The unaligned (`loadu`/`storeu`) intrinsics are used, so the
// buffers need no 16-byte alignment.
template <>
__m128i Load<__m128i>(const std::int32_t* src) {
  const __m128i* const vec_src = reinterpret_cast<const __m128i*>(src);
  return _mm_loadu_si128(vec_src);
}
template <>
int16x8_m128i Load<int16x8_m128i>(const std::int16_t* src) {
  const __m128i* const vec_src = reinterpret_cast<const __m128i*>(src);
  const __m128i bits = _mm_loadu_si128(vec_src);
  return to_int16x8_m128i(bits);
}
template <>
void Store<__m128i>(std::int32_t* dst, __m128i v) {
  __m128i* const vec_dst = reinterpret_cast<__m128i*>(dst);
  _mm_storeu_si128(vec_dst, v);
}
template <>
void Store<int16x8_m128i>(std::int16_t* dst, int16x8_m128i v) {
  __m128i* const vec_dst = reinterpret_cast<__m128i*>(dst);
  _mm_storeu_si128(vec_dst, v.v);
}
#endif  // GEMMLOWP_SSE4
#ifdef GEMMLOWP_MSA
// MIPS MSA specializations: v4i32 (4 x int32) and v8i16 (8 x int16) are
// moved with the __builtin_msa_ld_* / __builtin_msa_st_* builtins, using an
// offset argument of 0. The const_cast on the load side is presumably there
// because the load builtins take a non-const pointer; they only read.

// --- v4i32 ---
template <>
v4i32 Load<v4i32>(const std::int32_t* src) {
  return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
}
template <>
void Store<v4i32>(std::int32_t* dst, v4i32 v) {
  __builtin_msa_st_w(v, dst, 0);
}

// --- v8i16 ---
template <>
v8i16 Load<v8i16>(const std::int16_t* src) {
  return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
}
template <>
void Store<v8i16>(std::int16_t* dst, v8i16 v) {
  __builtin_msa_st_h(v, dst, 0);
}
#endif  // GEMMLOWP_MSA
96*5f39d1b3SJooyung Han 
#ifdef GEMMLOWP_AVX2
// AVX2 specializations. int32 lanes travel as a bare __m256i. int16 lanes
// travel as the int16x16_m256i wrapper, whose `.v` member holds the
// underlying __m256i. Unaligned (`loadu`/`storeu`) intrinsics are used.

// --- __m256i (int32 lanes) ---
template <>
__m256i Load<__m256i>(const std::int32_t* src) {
  const __m256i* const vec_src = reinterpret_cast<const __m256i*>(src);
  return _mm256_loadu_si256(vec_src);
}
template <>
void Store<__m256i>(std::int32_t* dst, __m256i v) {
  __m256i* const vec_dst = reinterpret_cast<__m256i*>(dst);
  _mm256_storeu_si256(vec_dst, v);
}

// --- int16x16_m256i (int16 lanes) ---
template <>
int16x16_m256i Load<int16x16_m256i>(const std::int16_t* src) {
  const __m256i* const vec_src = reinterpret_cast<const __m256i*>(src);
  const __m256i bits = _mm256_loadu_si256(vec_src);
  return to_int16x16_m256i(bits);
}
template <>
void Store<int16x16_m256i>(std::int16_t* dst, int16x16_m256i v) {
  __m256i* const vec_dst = reinterpret_cast<__m256i*>(dst);
  _mm256_storeu_si256(vec_dst, v.v);
}
#endif  // GEMMLOWP_AVX2
119*5f39d1b3SJooyung Han 
120*5f39d1b3SJooyung Han template <typename tSimdType>
121*5f39d1b3SJooyung Han class TestFixedPoint {
122*5f39d1b3SJooyung Han  public:
123*5f39d1b3SJooyung Han   using SimdType = tSimdType;
124*5f39d1b3SJooyung Han   using SimdTypeTraits = FixedPointRawTypeTraits<SimdType>;
125*5f39d1b3SJooyung Han   using ScalarType = typename SimdTypeTraits::ScalarRawType;
126*5f39d1b3SJooyung Han   static constexpr int kSimdLanes = SimdTypeTraits::kLanes;
127*5f39d1b3SJooyung Han   static constexpr int kScalarTypeBits = 8 * sizeof(ScalarType);
128*5f39d1b3SJooyung Han 
129*5f39d1b3SJooyung Han   // Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
130*5f39d1b3SJooyung Han   // Most (though not all) of the fixedpoint functionality being tested
131*5f39d1b3SJooyung Han   // consists of functions taking one fixedpoint value and returning one
132*5f39d1b3SJooyung Han   // fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
133*5f39d1b3SJooyung Han   // We factor a lot of testing boilerplate into a common TestUnaryOp function
134*5f39d1b3SJooyung Han   // taking a "unary op" object that fully describes the function to be tested.
135*5f39d1b3SJooyung Han   // These objects inherit UnaryOpBase mostly as a means to share some default
136*5f39d1b3SJooyung Han   // values for some properties.
137*5f39d1b3SJooyung Han   //
138*5f39d1b3SJooyung Han   // An important design element here is that the fixed-point values are passed
139*5f39d1b3SJooyung Han   // around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
140*5f39d1b3SJooyung Han   // as higher-level FixedPoint objects. The motivation for this design is 1) to
141*5f39d1b3SJooyung Han   // avoid having to templatize everything in the tIntegerBits parameter of
142*5f39d1b3SJooyung Han   // class FixedPoint, and 2) to allow directly testing low-level functions
143*5f39d1b3SJooyung Han   // operating on raw types (e.g. RoundingDivideByPOT) without needlessly
144*5f39d1b3SJooyung Han   // requiring
145*5f39d1b3SJooyung Han   // wrapping raw values in FixedPoint objects.
146*5f39d1b3SJooyung Han   class UnaryOpBase {
147*5f39d1b3SJooyung Han    public:
148*5f39d1b3SJooyung Han     // Min bound of the input range of this op. For example, an op only handling
149*5f39d1b3SJooyung Han     // nonnegative values would return 0.
MinInput() const150*5f39d1b3SJooyung Han     ScalarType MinInput() const {
151*5f39d1b3SJooyung Han       return std::numeric_limits<ScalarType>::min();
152*5f39d1b3SJooyung Han     }
153*5f39d1b3SJooyung Han     // Max bound of the input range of this op. For example, an op only handling
154*5f39d1b3SJooyung Han     // nonpositive values would return 0.
MaxInput() const155*5f39d1b3SJooyung Han     ScalarType MaxInput() const {
156*5f39d1b3SJooyung Han       return std::numeric_limits<ScalarType>::max();
157*5f39d1b3SJooyung Han     }
158*5f39d1b3SJooyung Han     // Tolerated difference between actual and reference ScalarType values.
159*5f39d1b3SJooyung Han     // Note that the corresponding real-numbers tolerance depends on the number
160*5f39d1b3SJooyung Han     // of integer bits of the fixed-point representation of the results of this
161*5f39d1b3SJooyung Han     // op.
162*5f39d1b3SJooyung Han     // For example, for an op returning fixed-point values with 0 integer bits,
163*5f39d1b3SJooyung Han     // the correspondence between real-number values and raw values is
164*5f39d1b3SJooyung Han     // real_number = (2^31) * raw_value.
Tolerance() const165*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return 0; }
166*5f39d1b3SJooyung Han   };
167*5f39d1b3SJooyung Han 
168*5f39d1b3SJooyung Han   // Op wrapping RoundingDivideByPOT
169*5f39d1b3SJooyung Han   class RoundingDivideByPOTOp final : public UnaryOpBase {
170*5f39d1b3SJooyung Han    public:
RoundingDivideByPOTOp(int exponent)171*5f39d1b3SJooyung Han     RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
ReferenceOp(ScalarType x) const172*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
173*5f39d1b3SJooyung Han       const double d = static_cast<double>(x) / (1ll << exponent_);
174*5f39d1b3SJooyung Han       return static_cast<ScalarType>(std::round(d));
175*5f39d1b3SJooyung Han     }
176*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const177*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
178*5f39d1b3SJooyung Han       return RoundingDivideByPOT(x, exponent_);
179*5f39d1b3SJooyung Han     }
180*5f39d1b3SJooyung Han 
181*5f39d1b3SJooyung Han    private:
182*5f39d1b3SJooyung Han     const int exponent_;
183*5f39d1b3SJooyung Han   };
184*5f39d1b3SJooyung Han 
185*5f39d1b3SJooyung Han   // Op wrapping SaturatingRoundingMultiplyByPOT
186*5f39d1b3SJooyung Han   template <int tExponent>
187*5f39d1b3SJooyung Han   class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
188*5f39d1b3SJooyung Han    public:
ReferenceOp(ScalarType x) const189*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
190*5f39d1b3SJooyung Han       const double d = static_cast<double>(x) * std::pow(2., tExponent);
191*5f39d1b3SJooyung Han       const double clamp_min = std::numeric_limits<ScalarType>::min();
192*5f39d1b3SJooyung Han       const double clamp_max = std::numeric_limits<ScalarType>::max();
193*5f39d1b3SJooyung Han       const double clamped = std::min(clamp_max, std::max(clamp_min, d));
194*5f39d1b3SJooyung Han       return static_cast<ScalarType>(std::round(clamped));
195*5f39d1b3SJooyung Han     }
196*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const197*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
198*5f39d1b3SJooyung Han       return SaturatingRoundingMultiplyByPOT<tExponent>(x);
199*5f39d1b3SJooyung Han     }
200*5f39d1b3SJooyung Han   };
201*5f39d1b3SJooyung Han 
202*5f39d1b3SJooyung Han   // Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
203*5f39d1b3SJooyung Han   class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
204*5f39d1b3SJooyung Han       : public UnaryOpBase {
205*5f39d1b3SJooyung Han    public:
MinInput() const206*5f39d1b3SJooyung Han     ScalarType MinInput() const { return -(1 << (kScalarTypeBits - 3)); }
MaxInput() const207*5f39d1b3SJooyung Han     ScalarType MaxInput() const { return 0; }
Tolerance() const208*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 1; }
ReferenceOp(ScalarType x) const209*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
210*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, 0>;
211*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
212*5f39d1b3SJooyung Han       const double e = std::exp(d);
213*5f39d1b3SJooyung Han       return F::FromDouble(e).raw();
214*5f39d1b3SJooyung Han     }
215*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const216*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
217*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, 0>;
218*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
219*5f39d1b3SJooyung Han       const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
220*5f39d1b3SJooyung Han       return e.raw();
221*5f39d1b3SJooyung Han     }
222*5f39d1b3SJooyung Han   };
223*5f39d1b3SJooyung Han 
224*5f39d1b3SJooyung Han   // Op wrapping exp_on_negative_values
225*5f39d1b3SJooyung Han   template <int tIntegerBits>
226*5f39d1b3SJooyung Han   class ExpOnNegativeValuesOp final : public UnaryOpBase {
227*5f39d1b3SJooyung Han    public:
MaxInput() const228*5f39d1b3SJooyung Han     ScalarType MaxInput() const { return 0; }
Tolerance() const229*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 500 : 2; }
ReferenceOp(ScalarType x) const230*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
231*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, tIntegerBits>;
232*5f39d1b3SJooyung Han       using F0 = FixedPoint<ScalarType, 0>;
233*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
234*5f39d1b3SJooyung Han       const double e = std::exp(d);
235*5f39d1b3SJooyung Han       return F0::FromDouble(e).raw();
236*5f39d1b3SJooyung Han     }
237*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const238*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
239*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, tIntegerBits>;
240*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
241*5f39d1b3SJooyung Han       return exp_on_negative_values(f).raw();
242*5f39d1b3SJooyung Han     }
243*5f39d1b3SJooyung Han   };
244*5f39d1b3SJooyung Han 
245*5f39d1b3SJooyung Han   // Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
246*5f39d1b3SJooyung Han   class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
247*5f39d1b3SJooyung Han    public:
MinInput() const248*5f39d1b3SJooyung Han     ScalarType MinInput() const { return 0; }
Tolerance() const249*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 12 : 11; }
ReferenceOp(ScalarType x) const250*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
251*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, 0>;
252*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
253*5f39d1b3SJooyung Han       const double e = (1 - d) / (1 + d);
254*5f39d1b3SJooyung Han       return F::FromDouble(e).raw();
255*5f39d1b3SJooyung Han     }
256*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const257*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
258*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, 0>;
259*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
260*5f39d1b3SJooyung Han       return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
261*5f39d1b3SJooyung Han     }
262*5f39d1b3SJooyung Han   };
263*5f39d1b3SJooyung Han 
264*5f39d1b3SJooyung Han   // Op wrapping tanh
265*5f39d1b3SJooyung Han   template <int tIntegerBits>
266*5f39d1b3SJooyung Han   class TanhOp final : public UnaryOpBase {
267*5f39d1b3SJooyung Han    public:
Tolerance() const268*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 310 : 12; }
ReferenceOp(ScalarType x) const269*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
270*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, tIntegerBits>;
271*5f39d1b3SJooyung Han       using F0 = FixedPoint<ScalarType, 0>;
272*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
273*5f39d1b3SJooyung Han       const double e = std::tanh(d);
274*5f39d1b3SJooyung Han       return F0::FromDouble(e).raw();
275*5f39d1b3SJooyung Han     }
276*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const277*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
278*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, tIntegerBits>;
279*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
280*5f39d1b3SJooyung Han       return tanh(f).raw();
281*5f39d1b3SJooyung Han     }
282*5f39d1b3SJooyung Han   };
283*5f39d1b3SJooyung Han 
284*5f39d1b3SJooyung Han   // Op wrapping one_over_one_plus_x_for_x_in_0_1
285*5f39d1b3SJooyung Han   class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
286*5f39d1b3SJooyung Han    public:
MinInput() const287*5f39d1b3SJooyung Han     ScalarType MinInput() const { return 0; }
Tolerance() const288*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 6 : 5; }
ReferenceOp(ScalarType x) const289*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
290*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, 0>;
291*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
292*5f39d1b3SJooyung Han       const double e = 1 / (1 + d);
293*5f39d1b3SJooyung Han       return F::FromDouble(e).raw();
294*5f39d1b3SJooyung Han     }
295*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const296*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
297*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, 0>;
298*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
299*5f39d1b3SJooyung Han       return one_over_one_plus_x_for_x_in_0_1(f).raw();
300*5f39d1b3SJooyung Han     }
301*5f39d1b3SJooyung Han   };
302*5f39d1b3SJooyung Han 
303*5f39d1b3SJooyung Han   // Op wrapping logistic
304*5f39d1b3SJooyung Han   template <int tIntegerBits>
305*5f39d1b3SJooyung Han   class LogisticOp final : public UnaryOpBase {
306*5f39d1b3SJooyung Han    public:
Tolerance() const307*5f39d1b3SJooyung Han     ScalarType Tolerance() const { return kScalarTypeBits == 32 ? 155 : 6; }
ReferenceOp(ScalarType x) const308*5f39d1b3SJooyung Han     ScalarType ReferenceOp(ScalarType x) const {
309*5f39d1b3SJooyung Han       using F = FixedPoint<ScalarType, tIntegerBits>;
310*5f39d1b3SJooyung Han       using F0 = FixedPoint<ScalarType, 0>;
311*5f39d1b3SJooyung Han       const double d = ToDouble(F::FromRaw(x));
312*5f39d1b3SJooyung Han       const double e = 1 / (1 + std::exp(-d));
313*5f39d1b3SJooyung Han       return F0::FromDouble(e).raw();
314*5f39d1b3SJooyung Han     }
315*5f39d1b3SJooyung Han     template <typename RawType>
Op(RawType x) const316*5f39d1b3SJooyung Han     RawType Op(RawType x) const {
317*5f39d1b3SJooyung Han       using F = FixedPoint<RawType, tIntegerBits>;
318*5f39d1b3SJooyung Han       const F f = F::FromRaw(x);
319*5f39d1b3SJooyung Han       return logistic(f).raw();
320*5f39d1b3SJooyung Han     }
321*5f39d1b3SJooyung Han   };
322*5f39d1b3SJooyung Han 
323*5f39d1b3SJooyung Han   // Tests a given op, on a given list of int32 input values.
324*5f39d1b3SJooyung Han   template <typename tUnaryOpType>
TestUnaryOp(const tUnaryOpType & unary_op,const std::vector<ScalarType> & testvals)325*5f39d1b3SJooyung Han   void TestUnaryOp(const tUnaryOpType& unary_op,
326*5f39d1b3SJooyung Han                    const std::vector<ScalarType>& testvals) {
327*5f39d1b3SJooyung Han     Check(0 == (testvals.size() % kSimdLanes));
328*5f39d1b3SJooyung Han     for (std::size_t i = 0; i < testvals.size(); i += kSimdLanes) {
329*5f39d1b3SJooyung Han       // First, clamp input values accoding to the MinInput() and MaxInput()
330*5f39d1b3SJooyung Han       // bounds returned by the op.
331*5f39d1b3SJooyung Han       ScalarType input[kSimdLanes] = {0};
332*5f39d1b3SJooyung Han       for (std::size_t j = 0; j < kSimdLanes; j++) {
333*5f39d1b3SJooyung Han         const ScalarType raw_input = testvals[i + j];
334*5f39d1b3SJooyung Han         input[j] = std::min(unary_op.MaxInput(),
335*5f39d1b3SJooyung Han                             std::max(unary_op.MinInput(), raw_input));
336*5f39d1b3SJooyung Han       }
337*5f39d1b3SJooyung Han       // Compute reference results and check that the actual results on
338*5f39d1b3SJooyung Han       // scalar inputs agree with them, to the Tolerance() returned by the op.
339*5f39d1b3SJooyung Han       ScalarType reference[kSimdLanes] = {0};
340*5f39d1b3SJooyung Han       ScalarType actual_scalar[kSimdLanes] = {0};
341*5f39d1b3SJooyung Han       for (std::size_t j = 0; j < kSimdLanes; j++) {
342*5f39d1b3SJooyung Han         reference[j] = unary_op.ReferenceOp(input[j]);
343*5f39d1b3SJooyung Han         actual_scalar[j] = unary_op.Op(input[j]);
344*5f39d1b3SJooyung Han         const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
345*5f39d1b3SJooyung Han                                   static_cast<std::int64_t>(reference[j]);
346*5f39d1b3SJooyung Han         if (std::abs(diff) > unary_op.Tolerance()) {
347*5f39d1b3SJooyung Han           fprintf(stderr, "abs(diff) (%" PRId64 ") > tolerance (%d)\n", diff,
348*5f39d1b3SJooyung Han                   unary_op.Tolerance());
349*5f39d1b3SJooyung Han         }
350*5f39d1b3SJooyung Han         Check(std::abs(diff) <= unary_op.Tolerance());
351*5f39d1b3SJooyung Han       }
352*5f39d1b3SJooyung Han       // Check that the actual results on SIMD inputs agree *exactly* with the
353*5f39d1b3SJooyung Han       // actual results on scalar inputs. I.e. SIMD must make absolutely no
354*5f39d1b3SJooyung Han       // difference
355*5f39d1b3SJooyung Han       // to the results, regardless of the fact that both scalar and SIMD
356*5f39d1b3SJooyung Han       // results may differ from the reference results.
357*5f39d1b3SJooyung Han       ScalarType actual_simd[kSimdLanes] = {0};
358*5f39d1b3SJooyung Han       Store<SimdType>(actual_simd, unary_op.Op(Load<SimdType>(input)));
359*5f39d1b3SJooyung Han       for (std::size_t j = 0; j < kSimdLanes; j++) {
360*5f39d1b3SJooyung Han         if (actual_simd[j] != actual_scalar[j]) {
361*5f39d1b3SJooyung Han           fprintf(stderr, "SIMD (%d) != scalar (%d)\n", actual_simd[j],
362*5f39d1b3SJooyung Han                   actual_scalar[j]);
363*5f39d1b3SJooyung Han         }
364*5f39d1b3SJooyung Han         Check(actual_simd[j] == actual_scalar[j]);
365*5f39d1b3SJooyung Han       }
366*5f39d1b3SJooyung Han     }
367*5f39d1b3SJooyung Han   }
368*5f39d1b3SJooyung Han 
369*5f39d1b3SJooyung Han   template <int tIntegerBits>
test_convert(FixedPoint<ScalarType,tIntegerBits> x)370*5f39d1b3SJooyung Han   void test_convert(FixedPoint<ScalarType, tIntegerBits> x) {
371*5f39d1b3SJooyung Han     typedef FixedPoint<ScalarType, tIntegerBits> F;
372*5f39d1b3SJooyung Han     F y = F::FromDouble(ToDouble(x));
373*5f39d1b3SJooyung Han     Check(y == x);
374*5f39d1b3SJooyung Han   }
375*5f39d1b3SJooyung Han 
376*5f39d1b3SJooyung Han   template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(FixedPoint<ScalarType,tIntegerBits_a> a)377*5f39d1b3SJooyung Han   void test_Rescale(FixedPoint<ScalarType, tIntegerBits_a> a) {
378*5f39d1b3SJooyung Han     FixedPoint<ScalarType, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
379*5f39d1b3SJooyung Han     FixedPoint<ScalarType, tIntegerBits_b> expected =
380*5f39d1b3SJooyung Han         FixedPoint<ScalarType, tIntegerBits_b>::FromDouble(ToDouble(a));
381*5f39d1b3SJooyung Han     Check(actual == expected);
382*5f39d1b3SJooyung Han   }
383*5f39d1b3SJooyung Han 
384*5f39d1b3SJooyung Han   template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(const std::vector<ScalarType> & testvals)385*5f39d1b3SJooyung Han   void test_Rescale(const std::vector<ScalarType>& testvals) {
386*5f39d1b3SJooyung Han     for (auto a : testvals) {
387*5f39d1b3SJooyung Han       FixedPoint<ScalarType, tIntegerBits_a> aq;
388*5f39d1b3SJooyung Han       aq.raw() = a;
389*5f39d1b3SJooyung Han       test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
390*5f39d1b3SJooyung Han     }
391*5f39d1b3SJooyung Han   }
392*5f39d1b3SJooyung Han 
393*5f39d1b3SJooyung Han   template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(FixedPoint<ScalarType,tIntegerBits_a> a,FixedPoint<ScalarType,tIntegerBits_b> b)394*5f39d1b3SJooyung Han   void test_mul(FixedPoint<ScalarType, tIntegerBits_a> a,
395*5f39d1b3SJooyung Han                 FixedPoint<ScalarType, tIntegerBits_b> b) {
396*5f39d1b3SJooyung Han     static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
397*5f39d1b3SJooyung Han     using ProductFixedPoint = FixedPoint<ScalarType, ProductIntegerBits>;
398*5f39d1b3SJooyung Han     ProductFixedPoint ab;
399*5f39d1b3SJooyung Han     ab = a * b;
400*5f39d1b3SJooyung Han     double a_double = ToDouble(a);
401*5f39d1b3SJooyung Han     double b_double = ToDouble(b);
402*5f39d1b3SJooyung Han     double ab_double = a_double * b_double;
403*5f39d1b3SJooyung Han     ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
404*5f39d1b3SJooyung Han     std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
405*5f39d1b3SJooyung Han     Check(std::abs(diff) <= 1);
406*5f39d1b3SJooyung Han   }
407*5f39d1b3SJooyung Han 
408*5f39d1b3SJooyung Han   template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(const std::vector<ScalarType> & testvals)409*5f39d1b3SJooyung Han   void test_mul(const std::vector<ScalarType>& testvals) {
410*5f39d1b3SJooyung Han     for (auto a : testvals) {
411*5f39d1b3SJooyung Han       for (auto b : testvals) {
412*5f39d1b3SJooyung Han         FixedPoint<ScalarType, tIntegerBits_a> aq;
413*5f39d1b3SJooyung Han         FixedPoint<ScalarType, tIntegerBits_b> bq;
414*5f39d1b3SJooyung Han         aq.raw() = a;
415*5f39d1b3SJooyung Han         bq.raw() = b;
416*5f39d1b3SJooyung Han         test_mul(aq, bq);
417*5f39d1b3SJooyung Han       }
418*5f39d1b3SJooyung Han     }
419*5f39d1b3SJooyung Han   }
420*5f39d1b3SJooyung Han 
421*5f39d1b3SJooyung Han   template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(FixedPoint<ScalarType,tIntegerBits_a> a)422*5f39d1b3SJooyung Han   void test_ExactMulByPot(FixedPoint<ScalarType, tIntegerBits_a> a) {
423*5f39d1b3SJooyung Han     double x = ToDouble(a) * std::pow(2.0, tExponent);
424*5f39d1b3SJooyung Han     double y = ToDouble(ExactMulByPot<tExponent>(a));
425*5f39d1b3SJooyung Han     Check(x == y);
426*5f39d1b3SJooyung Han   }
427*5f39d1b3SJooyung Han 
428*5f39d1b3SJooyung Han   template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(const std::vector<ScalarType> & testvals)429*5f39d1b3SJooyung Han   void test_ExactMulByPot(const std::vector<ScalarType>& testvals) {
430*5f39d1b3SJooyung Han     for (auto a : testvals) {
431*5f39d1b3SJooyung Han       FixedPoint<ScalarType, tIntegerBits_a> aq;
432*5f39d1b3SJooyung Han       aq.raw() = a;
433*5f39d1b3SJooyung Han       test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
434*5f39d1b3SJooyung Han     }
435*5f39d1b3SJooyung Han   }
436*5f39d1b3SJooyung Han 
437*5f39d1b3SJooyung Han   // Make the list of test values to test each op against.
MakeTestVals()438*5f39d1b3SJooyung Han   std::vector<ScalarType> MakeTestVals() {
439*5f39d1b3SJooyung Han     std::vector<ScalarType> testvals;
440*5f39d1b3SJooyung Han 
441*5f39d1b3SJooyung Han     for (int i = 0; i < kScalarTypeBits - 1; i++) {
442*5f39d1b3SJooyung Han       testvals.push_back((1 << i) - 2);
443*5f39d1b3SJooyung Han       testvals.push_back((1 << i) - 1);
444*5f39d1b3SJooyung Han       testvals.push_back((1 << i));
445*5f39d1b3SJooyung Han       testvals.push_back((1 << i) + 1);
446*5f39d1b3SJooyung Han       testvals.push_back((1 << i) + 2);
447*5f39d1b3SJooyung Han       testvals.push_back(-(1 << i) - 2);
448*5f39d1b3SJooyung Han       testvals.push_back(-(1 << i) - 1);
449*5f39d1b3SJooyung Han       testvals.push_back(-(1 << i));
450*5f39d1b3SJooyung Han       testvals.push_back(-(1 << i) + 1);
451*5f39d1b3SJooyung Han       testvals.push_back(-(1 << i) + 2);
452*5f39d1b3SJooyung Han     }
453*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::min());
454*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::min() + 1);
455*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::min() + 2);
456*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::max() - 2);
457*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::max() - 1);
458*5f39d1b3SJooyung Han     testvals.push_back(std::numeric_limits<ScalarType>::max());
459*5f39d1b3SJooyung Han 
460*5f39d1b3SJooyung Han     std::mt19937 random_engine;
461*5f39d1b3SJooyung Han     std::uniform_int_distribution<ScalarType> uniform_distribution(
462*5f39d1b3SJooyung Han         std::numeric_limits<ScalarType>::min(),
463*5f39d1b3SJooyung Han         std::numeric_limits<ScalarType>::max());
464*5f39d1b3SJooyung Han     for (int i = 0; i < 1000; i++) {
465*5f39d1b3SJooyung Han       testvals.push_back(uniform_distribution(random_engine));
466*5f39d1b3SJooyung Han     }
467*5f39d1b3SJooyung Han 
468*5f39d1b3SJooyung Han     // SIMD tests will require the length of testvals to be a multiple
469*5f39d1b3SJooyung Han     // of SIMD vector size.
470*5f39d1b3SJooyung Han     while (testvals.size() % kSimdLanes) {
471*5f39d1b3SJooyung Han       testvals.push_back(0);
472*5f39d1b3SJooyung Han     }
473*5f39d1b3SJooyung Han 
474*5f39d1b3SJooyung Han     std::sort(testvals.begin(), testvals.end());
475*5f39d1b3SJooyung Han     return testvals;
476*5f39d1b3SJooyung Han   }
477*5f39d1b3SJooyung Han 
RunTests(const char * msg)478*5f39d1b3SJooyung Han   void RunTests(const char* msg) {
479*5f39d1b3SJooyung Han     const std::vector<ScalarType> testvals = MakeTestVals();
480*5f39d1b3SJooyung Han 
481*5f39d1b3SJooyung Han     for (int s = 0; s < kScalarTypeBits; s++) {
482*5f39d1b3SJooyung Han       TestUnaryOp(RoundingDivideByPOTOp(s), testvals);
483*5f39d1b3SJooyung Han     }
484*5f39d1b3SJooyung Han 
485*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1 - kScalarTypeBits>(),
486*5f39d1b3SJooyung Han                 testvals);
487*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2 - kScalarTypeBits>(),
488*5f39d1b3SJooyung Han                 testvals);
489*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3 - kScalarTypeBits>(),
490*5f39d1b3SJooyung Han                 testvals);
491*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<14 - kScalarTypeBits>(),
492*5f39d1b3SJooyung Han                 testvals);
493*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15 - kScalarTypeBits>(),
494*5f39d1b3SJooyung Han                 testvals);
495*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals);
496*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals);
497*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals);
498*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals);
499*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals);
500*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals);
501*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals);
502*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals);
503*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals);
504*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals);
505*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals);
506*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 15>(),
507*5f39d1b3SJooyung Han                 testvals);
508*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 14>(),
509*5f39d1b3SJooyung Han                 testvals);
510*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 3>(),
511*5f39d1b3SJooyung Han                 testvals);
512*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 2>(),
513*5f39d1b3SJooyung Han                 testvals);
514*5f39d1b3SJooyung Han     TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<kScalarTypeBits - 1>(),
515*5f39d1b3SJooyung Han                 testvals);
516*5f39d1b3SJooyung Han 
517*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(), testvals);
518*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals);
519*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals);
520*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals);
521*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals);
522*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals);
523*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals);
524*5f39d1b3SJooyung Han     TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals);
525*5f39d1b3SJooyung Han 
526*5f39d1b3SJooyung Han     TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals);
527*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<0>(), testvals);
528*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<1>(), testvals);
529*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<2>(), testvals);
530*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<3>(), testvals);
531*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<4>(), testvals);
532*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<5>(), testvals);
533*5f39d1b3SJooyung Han     TestUnaryOp(TanhOp<6>(), testvals);
534*5f39d1b3SJooyung Han 
535*5f39d1b3SJooyung Han     TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals);
536*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<0>(), testvals);
537*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<1>(), testvals);
538*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<2>(), testvals);
539*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<3>(), testvals);
540*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<4>(), testvals);
541*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<5>(), testvals);
542*5f39d1b3SJooyung Han     TestUnaryOp(LogisticOp<6>(), testvals);
543*5f39d1b3SJooyung Han 
544*5f39d1b3SJooyung Han     for (auto a : testvals) {
545*5f39d1b3SJooyung Han       FixedPoint<ScalarType, 4> x;
546*5f39d1b3SJooyung Han       x.raw() = a;
547*5f39d1b3SJooyung Han       test_convert(x);
548*5f39d1b3SJooyung Han     }
549*5f39d1b3SJooyung Han 
550*5f39d1b3SJooyung Han     test_mul<0, 0>(testvals);
551*5f39d1b3SJooyung Han     test_mul<0, 1>(testvals);
552*5f39d1b3SJooyung Han     test_mul<2, 0>(testvals);
553*5f39d1b3SJooyung Han     test_mul<1, 1>(testvals);
554*5f39d1b3SJooyung Han     test_mul<4, 4>(testvals);
555*5f39d1b3SJooyung Han     test_mul<3, 5>(testvals);
556*5f39d1b3SJooyung Han     test_mul<7, 2>(testvals);
557*5f39d1b3SJooyung Han     test_mul<kScalarTypeBits / 2 - 1, kScalarTypeBits / 2 - 2>(testvals);
558*5f39d1b3SJooyung Han 
559*5f39d1b3SJooyung Han     test_Rescale<0, 0>(testvals);
560*5f39d1b3SJooyung Han     test_Rescale<0, 1>(testvals);
561*5f39d1b3SJooyung Han     test_Rescale<2, 0>(testvals);
562*5f39d1b3SJooyung Han     test_Rescale<4, 4>(testvals);
563*5f39d1b3SJooyung Han     test_Rescale<4, 5>(testvals);
564*5f39d1b3SJooyung Han     test_Rescale<6, 3>(testvals);
565*5f39d1b3SJooyung Han     test_Rescale<13, 9>(testvals);
566*5f39d1b3SJooyung Han 
567*5f39d1b3SJooyung Han     test_ExactMulByPot<0, 0>(testvals);
568*5f39d1b3SJooyung Han     test_ExactMulByPot<0, 4>(testvals);
569*5f39d1b3SJooyung Han     test_ExactMulByPot<1, 4>(testvals);
570*5f39d1b3SJooyung Han     test_ExactMulByPot<3, 2>(testvals);
571*5f39d1b3SJooyung Han     test_ExactMulByPot<-4, 5>(testvals);
572*5f39d1b3SJooyung Han     test_ExactMulByPot<-2, 6>(testvals);
573*5f39d1b3SJooyung Han 
574*5f39d1b3SJooyung Han     fprintf(stderr, "PASS (%s)\n", msg);
575*5f39d1b3SJooyung Han   }
576*5f39d1b3SJooyung Han };
577*5f39d1b3SJooyung Han 
578*5f39d1b3SJooyung Han }  // end anonymous namespace
579*5f39d1b3SJooyung Han 
580*5f39d1b3SJooyung Han }  // end namespace gemmlowp
581*5f39d1b3SJooyung Han 
main()582*5f39d1b3SJooyung Han int main() {
583*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<std::int32_t>().RunTests("Scalar int32");
584*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<std::int16_t>().RunTests("Scalar int16");
585*5f39d1b3SJooyung Han #ifdef GEMMLOWP_SSE4
586*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<__m128i>().RunTests("SSE4 __m128i = int32x4");
587*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<gemmlowp::int16x8_m128i>().RunTests(
588*5f39d1b3SJooyung Han       "SSE4 __m128i = int16x8");
589*5f39d1b3SJooyung Han #endif
590*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON
591*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<int32x4_t>().RunTests("NEON int32x4_t");
592*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<int16x8_t>().RunTests("NEON int16x8_t");
593*5f39d1b3SJooyung Han #endif
594*5f39d1b3SJooyung Han #ifdef GEMMLOWP_MSA
595*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<v4i32>().RunTests("MSA v4i32");
596*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<v8i16>().RunTests("MSA v8i16");
597*5f39d1b3SJooyung Han #endif
598*5f39d1b3SJooyung Han #ifdef GEMMLOWP_AVX2
599*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<__m256i>().RunTests("AVX __m256i");
600*5f39d1b3SJooyung Han   gemmlowp::TestFixedPoint<gemmlowp::int16x16_m256i>().RunTests(
601*5f39d1b3SJooyung Han       "AVX2 __m256i = int16x16");
602*5f39d1b3SJooyung Han #endif
603*5f39d1b3SJooyung Han }
604