xref: /aosp_15_r20/external/gemmlowp/fixedpoint/fixedpoint_sse.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 Google Inc. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // fixedpoint_SSE.h: optimized SSE specializations of the templates
16*5f39d1b3SJooyung Han // in fixedpoint.h.
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #include <smmintrin.h>
22*5f39d1b3SJooyung Han #include "fixedpoint.h"
23*5f39d1b3SJooyung Han 
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han // SSE intrinsics are not finely typed: there is a single __m128i vector
27*5f39d1b3SJooyung Han // type that does not distinguish between "int32x4" and "int16x8" use
28*5f39d1b3SJooyung Han // cases, unlike the NEON equivalents. Because we had initially focused
29*5f39d1b3SJooyung Han // on int32x4, we did not pay attention and specialized these fixedpoint
30*5f39d1b3SJooyung Han // templates directly for __m128i hardcoding the int32x4 semantics,
31*5f39d1b3SJooyung Han // not leaving room for int16x8 semantics. Amending that by adding a separate
32*5f39d1b3SJooyung Han // data type, int16x8_m128i, that wraps __m128i while being a separate
33*5f39d1b3SJooyung Han // type.
// Thin wrapper giving int16x8 semantics to a raw __m128i register, so that
// the fixedpoint templates can be specialized for it separately from the
// int32x4 (bare __m128i) case.
struct int16x8_m128i {
  __m128i v;
};

// Keep int16x8_m128i trivially constructible/destructible and provide
// easily optimized helper function.
inline int16x8_m128i to_int16x8_m128i(__m128i w) {
  int16x8_m128i r;
  r.v = w;
  return r;
}
45*5f39d1b3SJooyung Han 
// Raw-type traits: bare __m128i is treated as 4 lanes of int32.
template <>
struct FixedPointRawTypeTraits<__m128i> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 4;
};

// Raw-type traits: int16x8_m128i is 8 lanes of int16.
template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 8;
};
57*5f39d1b3SJooyung Han 
58*5f39d1b3SJooyung Han template <>
59*5f39d1b3SJooyung Han inline __m128i BitAnd(__m128i a, __m128i b) {
60*5f39d1b3SJooyung Han   return _mm_and_si128(a, b);
61*5f39d1b3SJooyung Han }
62*5f39d1b3SJooyung Han 
63*5f39d1b3SJooyung Han template <>
64*5f39d1b3SJooyung Han inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
65*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_and_si128(a.v, b.v));
66*5f39d1b3SJooyung Han }
67*5f39d1b3SJooyung Han 
68*5f39d1b3SJooyung Han template <>
69*5f39d1b3SJooyung Han inline __m128i BitOr(__m128i a, __m128i b) {
70*5f39d1b3SJooyung Han   return _mm_or_si128(a, b);
71*5f39d1b3SJooyung Han }
72*5f39d1b3SJooyung Han 
73*5f39d1b3SJooyung Han template <>
74*5f39d1b3SJooyung Han inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
75*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_or_si128(a.v, b.v));
76*5f39d1b3SJooyung Han }
77*5f39d1b3SJooyung Han 
78*5f39d1b3SJooyung Han template <>
79*5f39d1b3SJooyung Han inline __m128i BitXor(__m128i a, __m128i b) {
80*5f39d1b3SJooyung Han   return _mm_xor_si128(a, b);
81*5f39d1b3SJooyung Han }
82*5f39d1b3SJooyung Han 
83*5f39d1b3SJooyung Han template <>
84*5f39d1b3SJooyung Han inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
85*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_xor_si128(a.v, b.v));
86*5f39d1b3SJooyung Han }
87*5f39d1b3SJooyung Han 
88*5f39d1b3SJooyung Han template <>
89*5f39d1b3SJooyung Han inline __m128i BitNot(__m128i a) {
90*5f39d1b3SJooyung Han   return _mm_andnot_si128(a, _mm_set1_epi32(-1));
91*5f39d1b3SJooyung Han }
92*5f39d1b3SJooyung Han 
93*5f39d1b3SJooyung Han template <>
94*5f39d1b3SJooyung Han inline int16x8_m128i BitNot(int16x8_m128i a) {
95*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
96*5f39d1b3SJooyung Han }
97*5f39d1b3SJooyung Han 
98*5f39d1b3SJooyung Han template <>
99*5f39d1b3SJooyung Han inline __m128i Add(__m128i a, __m128i b) {
100*5f39d1b3SJooyung Han   return _mm_add_epi32(a, b);
101*5f39d1b3SJooyung Han }
102*5f39d1b3SJooyung Han 
103*5f39d1b3SJooyung Han template <>
104*5f39d1b3SJooyung Han inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
105*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_add_epi16(a.v, b.v));
106*5f39d1b3SJooyung Han }
107*5f39d1b3SJooyung Han 
108*5f39d1b3SJooyung Han template <>
109*5f39d1b3SJooyung Han inline __m128i Mul(__m128i a, __m128i b) {
110*5f39d1b3SJooyung Han   return _mm_mullo_epi32(a, b);
111*5f39d1b3SJooyung Han }
112*5f39d1b3SJooyung Han 
113*5f39d1b3SJooyung Han template <>
114*5f39d1b3SJooyung Han inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
115*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
116*5f39d1b3SJooyung Han }
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han template <>
119*5f39d1b3SJooyung Han inline __m128i Sub(__m128i a, __m128i b) {
120*5f39d1b3SJooyung Han   return _mm_sub_epi32(a, b);
121*5f39d1b3SJooyung Han }
122*5f39d1b3SJooyung Han 
123*5f39d1b3SJooyung Han template <>
124*5f39d1b3SJooyung Han inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
125*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_sub_epi16(a.v, b.v));
126*5f39d1b3SJooyung Han }
127*5f39d1b3SJooyung Han 
128*5f39d1b3SJooyung Han template <>
129*5f39d1b3SJooyung Han inline __m128i Neg(__m128i a) {
130*5f39d1b3SJooyung Han   return _mm_sign_epi32(a, _mm_set1_epi32(-1));
131*5f39d1b3SJooyung Han }
132*5f39d1b3SJooyung Han 
133*5f39d1b3SJooyung Han template <>
134*5f39d1b3SJooyung Han inline int16x8_m128i Neg(int16x8_m128i a) {
135*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
136*5f39d1b3SJooyung Han }
137*5f39d1b3SJooyung Han 
138*5f39d1b3SJooyung Han template <>
139*5f39d1b3SJooyung Han inline __m128i ShiftLeft(__m128i a, int offset) {
140*5f39d1b3SJooyung Han   return _mm_slli_epi32(a, offset);
141*5f39d1b3SJooyung Han }
142*5f39d1b3SJooyung Han 
143*5f39d1b3SJooyung Han template <>
144*5f39d1b3SJooyung Han inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
145*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_slli_epi16(a.v, offset));
146*5f39d1b3SJooyung Han }
147*5f39d1b3SJooyung Han 
148*5f39d1b3SJooyung Han template <>
149*5f39d1b3SJooyung Han inline __m128i ShiftRight(__m128i a, int offset) {
150*5f39d1b3SJooyung Han   return _mm_srai_epi32(a, offset);
151*5f39d1b3SJooyung Han }
152*5f39d1b3SJooyung Han 
153*5f39d1b3SJooyung Han template <>
154*5f39d1b3SJooyung Han inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
155*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_srai_epi16(a.v, offset));
156*5f39d1b3SJooyung Han }
157*5f39d1b3SJooyung Han 
158*5f39d1b3SJooyung Han template <>
159*5f39d1b3SJooyung Han inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
160*5f39d1b3SJooyung Han                                __m128i else_val) {
161*5f39d1b3SJooyung Han   // borrowed from Intel's arm_neon_sse.h header.
162*5f39d1b3SJooyung Han   return _mm_or_si128(_mm_and_si128(if_mask, then_val),
163*5f39d1b3SJooyung Han                       _mm_andnot_si128(if_mask, else_val));
164*5f39d1b3SJooyung Han }
165*5f39d1b3SJooyung Han 
166*5f39d1b3SJooyung Han template <>
167*5f39d1b3SJooyung Han inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
168*5f39d1b3SJooyung Han                                      int16x8_m128i then_val,
169*5f39d1b3SJooyung Han                                      int16x8_m128i else_val) {
170*5f39d1b3SJooyung Han   // borrowed from Intel's arm_neon_sse.h header.
171*5f39d1b3SJooyung Han   return to_int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
172*5f39d1b3SJooyung Han }
173*5f39d1b3SJooyung Han 
174*5f39d1b3SJooyung Han template <>
175*5f39d1b3SJooyung Han inline __m128i MaskIfEqual(__m128i a, __m128i b) {
176*5f39d1b3SJooyung Han   return _mm_cmpeq_epi32(a, b);
177*5f39d1b3SJooyung Han }
178*5f39d1b3SJooyung Han 
179*5f39d1b3SJooyung Han template <>
180*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
181*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
182*5f39d1b3SJooyung Han }
183*5f39d1b3SJooyung Han 
184*5f39d1b3SJooyung Han template <>
185*5f39d1b3SJooyung Han inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
186*5f39d1b3SJooyung Han   return BitNot(MaskIfEqual(a, b));
187*5f39d1b3SJooyung Han }
188*5f39d1b3SJooyung Han 
189*5f39d1b3SJooyung Han template <>
190*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
191*5f39d1b3SJooyung Han   return BitNot(MaskIfEqual(a, b));
192*5f39d1b3SJooyung Han }
193*5f39d1b3SJooyung Han 
194*5f39d1b3SJooyung Han template <>
195*5f39d1b3SJooyung Han inline __m128i MaskIfZero(__m128i a) {
196*5f39d1b3SJooyung Han   return MaskIfEqual(a, _mm_set1_epi32(0));
197*5f39d1b3SJooyung Han }
198*5f39d1b3SJooyung Han 
199*5f39d1b3SJooyung Han template <>
200*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
201*5f39d1b3SJooyung Han   return MaskIfEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
202*5f39d1b3SJooyung Han }
203*5f39d1b3SJooyung Han 
204*5f39d1b3SJooyung Han template <>
205*5f39d1b3SJooyung Han inline __m128i MaskIfNonZero(__m128i a) {
206*5f39d1b3SJooyung Han   return MaskIfNotEqual(a, _mm_set1_epi32(0));
207*5f39d1b3SJooyung Han }
208*5f39d1b3SJooyung Han 
209*5f39d1b3SJooyung Han template <>
210*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
211*5f39d1b3SJooyung Han   return MaskIfNotEqual(a, to_int16x8_m128i(_mm_set1_epi16(0)));
212*5f39d1b3SJooyung Han }
213*5f39d1b3SJooyung Han 
214*5f39d1b3SJooyung Han template <>
215*5f39d1b3SJooyung Han inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
216*5f39d1b3SJooyung Han   return _mm_cmpgt_epi32(a, b);
217*5f39d1b3SJooyung Han }
218*5f39d1b3SJooyung Han 
219*5f39d1b3SJooyung Han template <>
220*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
221*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
222*5f39d1b3SJooyung Han }
223*5f39d1b3SJooyung Han 
224*5f39d1b3SJooyung Han template <>
225*5f39d1b3SJooyung Han inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
226*5f39d1b3SJooyung Han   return _mm_cmplt_epi32(a, b);
227*5f39d1b3SJooyung Han }
228*5f39d1b3SJooyung Han 
229*5f39d1b3SJooyung Han template <>
230*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
231*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
232*5f39d1b3SJooyung Han }
233*5f39d1b3SJooyung Han 
234*5f39d1b3SJooyung Han template <>
235*5f39d1b3SJooyung Han inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
236*5f39d1b3SJooyung Han   return BitNot(MaskIfLessThan(a, b));
237*5f39d1b3SJooyung Han }
238*5f39d1b3SJooyung Han 
239*5f39d1b3SJooyung Han template <>
240*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
241*5f39d1b3SJooyung Han                                               int16x8_m128i b) {
242*5f39d1b3SJooyung Han   return BitNot(MaskIfLessThan(a, b));
243*5f39d1b3SJooyung Han }
244*5f39d1b3SJooyung Han 
245*5f39d1b3SJooyung Han template <>
246*5f39d1b3SJooyung Han inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
247*5f39d1b3SJooyung Han   return BitNot(MaskIfGreaterThan(a, b));
248*5f39d1b3SJooyung Han }
249*5f39d1b3SJooyung Han 
250*5f39d1b3SJooyung Han template <>
251*5f39d1b3SJooyung Han inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
252*5f39d1b3SJooyung Han   return BitNot(MaskIfGreaterThan(a, b));
253*5f39d1b3SJooyung Han }
254*5f39d1b3SJooyung Han 
255*5f39d1b3SJooyung Han /* Assumptions:
256*5f39d1b3SJooyung Han    - All and Any are used on masks.
257*5f39d1b3SJooyung Han    - masks are all_ones for true lanes, all_zeroes otherwise.
258*5f39d1b3SJooyung Han Hence, All means all 128bits set, and Any means any bit set.
259*5f39d1b3SJooyung Han */
260*5f39d1b3SJooyung Han 
261*5f39d1b3SJooyung Han template <>
262*5f39d1b3SJooyung Han inline bool All(__m128i a) {
263*5f39d1b3SJooyung Han   return _mm_testc_si128(a, a);
264*5f39d1b3SJooyung Han }
265*5f39d1b3SJooyung Han 
266*5f39d1b3SJooyung Han template <>
267*5f39d1b3SJooyung Han inline bool All(int16x8_m128i a) {
268*5f39d1b3SJooyung Han   return _mm_testc_si128(a.v, a.v);
269*5f39d1b3SJooyung Han }
270*5f39d1b3SJooyung Han 
271*5f39d1b3SJooyung Han template <>
272*5f39d1b3SJooyung Han inline bool Any(__m128i a) {
273*5f39d1b3SJooyung Han   return !_mm_testz_si128(a, a);
274*5f39d1b3SJooyung Han }
275*5f39d1b3SJooyung Han 
276*5f39d1b3SJooyung Han template <>
277*5f39d1b3SJooyung Han inline bool Any(int16x8_m128i a) {
278*5f39d1b3SJooyung Han   return !_mm_testz_si128(a.v, a.v);
279*5f39d1b3SJooyung Han }
280*5f39d1b3SJooyung Han 
// Rounding half-sum of int32x4 lanes: (a + b + 1) >> 1 computed as if in
// wider arithmetic, i.e. without losing the correct result when a + b
// overflows 32 bits.
template <>
inline __m128i RoundingHalfSum(__m128i a, __m128i b) {
  /* Commented-out alternative: halve each input before the add to avoid
     the overflow and costly test of checking whether an overflow occurred
     on the signed add, recovering the rounding bit from the discarded
     LSBs: */
  /* __m128i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
  /* round_bit_mask = _mm_set1_epi32(1); */
  /* a_over_2 = _mm_srai_epi32(a, 1); */
  /* b_over_2 = _mm_srai_epi32(b, 1); */
  /* sum = Add(a_over_2, b_over_2); */
  /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
  /* return Add(sum, round_bit); */

  /* Implemented approach: let the 32-bit add wrap, then detect the
     overflow and xor (flip) the sign bit if an overflow happened. */
  __m128i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
  one = _mm_set1_epi32(1);
  sign_bit_mask = _mm_set1_epi32(0x80000000);
  sum = Add(a, b);
  rounded_half_sum = _mm_srai_epi32(Add(sum, one), 1);
  // A signed overflow happened in a lane iff a and b agree in sign there
  // while the (wrapped, halved) result disagrees with both; in that case
  // flipping the result's sign bit restores the mathematically correct
  // half-sum.
  overflow =
      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
             sign_bit_mask);
  result = BitXor(rounded_half_sum, overflow);
  return result;
}
307*5f39d1b3SJooyung Han 
308*5f39d1b3SJooyung Han template <>
309*5f39d1b3SJooyung Han inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
310*5f39d1b3SJooyung Han   // Idea: go to unsigned to use _mm_avg_epu16,
311*5f39d1b3SJooyung Han   // borrowed from Intel's arm_neon_sse.h header.
312*5f39d1b3SJooyung Han   __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
313*5f39d1b3SJooyung Han   __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
314*5f39d1b3SJooyung Han   __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
315*5f39d1b3SJooyung Han   __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
316*5f39d1b3SJooyung Han   __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
317*5f39d1b3SJooyung Han   return to_int16x8_m128i(avg);
318*5f39d1b3SJooyung Han }
319*5f39d1b3SJooyung Han 
// int32x4 "saturating rounding doubling high multiply": per lane, the high
// 32 bits of 2*a*b with rounding. Products are formed pairwise with the
// 32x32->64 multiply, rounded, doubled, and the high halves re-interleaved.
template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
  __m128i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
  __m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
  __m128i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
  __m128i nudge;

  // saturation only happens if a == b == INT_MIN (the only case where
  // 2*a*b does not fit in 32 bits).
  min = _mm_set1_epi32(std::numeric_limits<std::int32_t>::min());
  saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));

  // a = a0 | a1 | a2 | a3
  // b = b0 | b1 | b2 | b3
  // _mm_mul_epi32 multiplies the even-indexed 32-bit lanes, so split the
  // odd lanes out with a 4-byte shift to cover all four products.
  a0_a2 = a;
  a1_a3 = _mm_srli_si128(a, 4);
  b0_b2 = b;
  b1_b3 = _mm_srli_si128(b, 4);

  a0b0_a2b2 = _mm_mul_epi32(a0_a2, b0_b2);
  a1b1_a3b3 = _mm_mul_epi32(a1_a3, b1_b3);

  // do the rounding and take into account that it will be doubled:
  // adding 2^30 before the doubling is adding 2^31 before the final
  // 31-bit shift, i.e. round-half-up.
  // NOTE(review): the scalar reference uses a sign-dependent nudge
  // (1 - 2^30 for negative products); this unconditional nudge rounds
  // negative ties differently — confirm whether that is intended.
  nudge = _mm_set1_epi64x(1 << 30);
  a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
  a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);

  // do the doubling
  a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
  a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);

  // get the high part of the products: take the upper 32 bits of the even
  // products (via the 4-byte shift) and blend in the odd products' upper
  // halves (blend mask 0xcc selects 16-bit fields in lanes 1 and 3).
  result = _mm_blend_epi16(_mm_srli_si128(a0b0_a2b2_rounded_2x, 4),
                           a1b1_a3b3_rounded_2x, 0xcc);

  // saturate those which overflowed.
  // NOTE(review): this selects `min` (INT32_MIN) for the overflowed lanes,
  // whereas the scalar/NEON contract saturates INT_MIN*INT_MIN to
  // INT32_MAX — and the unsaturated lane already holds INT32_MIN here, so
  // the select is a no-op as written. Verify against the scalar reference.
  return SelectUsingMask(saturation_mask, min, result);
}
357*5f39d1b3SJooyung Han 
358*5f39d1b3SJooyung Han template <>
359*5f39d1b3SJooyung Han inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
360*5f39d1b3SJooyung Han                                                        int16x8_m128i b) {
361*5f39d1b3SJooyung Han   // Idea: use _mm_mulhrs_epi16 then saturate with a bit-operation,
362*5f39d1b3SJooyung Han   // borrowed from Intel's arm_neon_sse.h header.
363*5f39d1b3SJooyung Han   __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
364*5f39d1b3SJooyung Han   __m128i saturation_mask =
365*5f39d1b3SJooyung Han       _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
366*5f39d1b3SJooyung Han   __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
367*5f39d1b3SJooyung Han   return to_int16x8_m128i(result);
368*5f39d1b3SJooyung Han }
369*5f39d1b3SJooyung Han 
370*5f39d1b3SJooyung Han template <>
371*5f39d1b3SJooyung Han inline __m128i Dup<__m128i>(std::int32_t x) {
372*5f39d1b3SJooyung Han   return _mm_set1_epi32(x);
373*5f39d1b3SJooyung Han }
374*5f39d1b3SJooyung Han 
375*5f39d1b3SJooyung Han template <>
376*5f39d1b3SJooyung Han inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
377*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_set1_epi16(x));
378*5f39d1b3SJooyung Han }
379*5f39d1b3SJooyung Han 
380*5f39d1b3SJooyung Han // So far this is only needed for int16.
381*5f39d1b3SJooyung Han template <>
382*5f39d1b3SJooyung Han inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
383*5f39d1b3SJooyung Han   return to_int16x8_m128i(_mm_adds_epi16(a.v, b.v));
384*5f39d1b3SJooyung Han }
385*5f39d1b3SJooyung Han 
386*5f39d1b3SJooyung Han }  // end namespace gemmlowp
387*5f39d1b3SJooyung Han 
388*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
389