xref: /aosp_15_r20/external/gemmlowp/fixedpoint/fixedpoint_avx.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han //
2*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
3*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
4*5f39d1b3SJooyung Han // You may obtain a copy of the License at
5*5f39d1b3SJooyung Han //
6*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
7*5f39d1b3SJooyung Han //
8*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
9*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
10*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
12*5f39d1b3SJooyung Han // limitations under the License.
13*5f39d1b3SJooyung Han 
14*5f39d1b3SJooyung Han // fixedpoint_avx.h: optimized avx specializations of the templates
15*5f39d1b3SJooyung Han // in fixedpoint.h.
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
19*5f39d1b3SJooyung Han 
20*5f39d1b3SJooyung Han #include <immintrin.h>
21*5f39d1b3SJooyung Han #include "fixedpoint.h"
22*5f39d1b3SJooyung Han #include "fixedpoint_sse.h"
23*5f39d1b3SJooyung Han 
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han struct int16x16_m256i {
27*5f39d1b3SJooyung Han   __m256i v;
28*5f39d1b3SJooyung Han };
29*5f39d1b3SJooyung Han 
30*5f39d1b3SJooyung Han // Keep int16x16_m256i trivially constructible/destructible and provide
31*5f39d1b3SJooyung Han // easily optimized helper function.
to_int16x16_m256i(__m256i w)32*5f39d1b3SJooyung Han inline int16x16_m256i to_int16x16_m256i(__m256i w) {
33*5f39d1b3SJooyung Han   int16x16_m256i r;
34*5f39d1b3SJooyung Han   r.v = w;
35*5f39d1b3SJooyung Han   return r;
36*5f39d1b3SJooyung Han }
37*5f39d1b3SJooyung Han 
38*5f39d1b3SJooyung Han template <>
39*5f39d1b3SJooyung Han struct FixedPointRawTypeTraits<__m256i> {
40*5f39d1b3SJooyung Han   typedef std::int32_t ScalarRawType;
41*5f39d1b3SJooyung Han   // TODO: This can actually support up to 8 lanes, so we should either
42*5f39d1b3SJooyung Han   // change to 8 or create int32x8_m256i struct to handle that case.
43*5f39d1b3SJooyung Han   static const int kLanes = 4;
44*5f39d1b3SJooyung Han };
45*5f39d1b3SJooyung Han 
46*5f39d1b3SJooyung Han template <>
47*5f39d1b3SJooyung Han struct FixedPointRawTypeTraits<int16x16_m256i> {
48*5f39d1b3SJooyung Han   typedef std::int16_t ScalarRawType;
49*5f39d1b3SJooyung Han   static const int kLanes = 16;
50*5f39d1b3SJooyung Han };
51*5f39d1b3SJooyung Han 
52*5f39d1b3SJooyung Han template <>
53*5f39d1b3SJooyung Han inline __m256i BitAnd(__m256i a, __m256i b) {
54*5f39d1b3SJooyung Han   return _mm256_and_si256(a, b);
55*5f39d1b3SJooyung Han }
56*5f39d1b3SJooyung Han 
57*5f39d1b3SJooyung Han template <>
58*5f39d1b3SJooyung Han inline int16x16_m256i BitAnd(int16x16_m256i a, int16x16_m256i b) {
59*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_and_si256(a.v, b.v));
60*5f39d1b3SJooyung Han }
61*5f39d1b3SJooyung Han 
62*5f39d1b3SJooyung Han template <>
63*5f39d1b3SJooyung Han inline __m256i BitOr(__m256i a, __m256i b) {
64*5f39d1b3SJooyung Han   return _mm256_or_si256(a, b);
65*5f39d1b3SJooyung Han }
66*5f39d1b3SJooyung Han 
67*5f39d1b3SJooyung Han template <>
68*5f39d1b3SJooyung Han inline int16x16_m256i BitOr(int16x16_m256i a, int16x16_m256i b) {
69*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_or_si256(a.v, b.v));
70*5f39d1b3SJooyung Han }
71*5f39d1b3SJooyung Han 
72*5f39d1b3SJooyung Han template <>
73*5f39d1b3SJooyung Han inline __m256i BitXor(__m256i a, __m256i b) {
74*5f39d1b3SJooyung Han   return _mm256_xor_si256(a, b);
75*5f39d1b3SJooyung Han }
76*5f39d1b3SJooyung Han 
77*5f39d1b3SJooyung Han template <>
78*5f39d1b3SJooyung Han inline int16x16_m256i BitXor(int16x16_m256i a, int16x16_m256i b) {
79*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_xor_si256(a.v, b.v));
80*5f39d1b3SJooyung Han }
81*5f39d1b3SJooyung Han 
82*5f39d1b3SJooyung Han template <>
83*5f39d1b3SJooyung Han inline __m256i BitNot(__m256i a) {
84*5f39d1b3SJooyung Han   return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
85*5f39d1b3SJooyung Han }
86*5f39d1b3SJooyung Han 
87*5f39d1b3SJooyung Han template <>
88*5f39d1b3SJooyung Han inline int16x16_m256i BitNot(int16x16_m256i a) {
89*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_andnot_si256(a.v, _mm256_set1_epi16(-1)));
90*5f39d1b3SJooyung Han }
91*5f39d1b3SJooyung Han 
92*5f39d1b3SJooyung Han template <>
93*5f39d1b3SJooyung Han inline __m256i Add(__m256i a, __m256i b) {
94*5f39d1b3SJooyung Han   return _mm256_add_epi32(a, b);
95*5f39d1b3SJooyung Han }
96*5f39d1b3SJooyung Han 
97*5f39d1b3SJooyung Han template <>
98*5f39d1b3SJooyung Han inline int16x16_m256i Add(int16x16_m256i a, int16x16_m256i b) {
99*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_add_epi16(a.v, b.v));
100*5f39d1b3SJooyung Han }
101*5f39d1b3SJooyung Han 
102*5f39d1b3SJooyung Han template <>
103*5f39d1b3SJooyung Han inline __m256i Mul(__m256i a, __m256i b) {
104*5f39d1b3SJooyung Han   return _mm256_mullo_epi32(a, b);
105*5f39d1b3SJooyung Han }
106*5f39d1b3SJooyung Han 
107*5f39d1b3SJooyung Han template <>
108*5f39d1b3SJooyung Han inline int16x16_m256i Mul(int16x16_m256i a, int16x16_m256i b) {
109*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_mullo_epi16(a.v, b.v));
110*5f39d1b3SJooyung Han }
111*5f39d1b3SJooyung Han 
112*5f39d1b3SJooyung Han template <>
113*5f39d1b3SJooyung Han inline __m256i Sub(__m256i a, __m256i b) {
114*5f39d1b3SJooyung Han   return _mm256_sub_epi32(a, b);
115*5f39d1b3SJooyung Han }
116*5f39d1b3SJooyung Han 
117*5f39d1b3SJooyung Han template <>
118*5f39d1b3SJooyung Han inline int16x16_m256i Sub(int16x16_m256i a, int16x16_m256i b) {
119*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_sub_epi16(a.v, b.v));
120*5f39d1b3SJooyung Han }
121*5f39d1b3SJooyung Han 
122*5f39d1b3SJooyung Han template <>
123*5f39d1b3SJooyung Han inline __m256i Neg(__m256i a) {
124*5f39d1b3SJooyung Han   return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
125*5f39d1b3SJooyung Han }
126*5f39d1b3SJooyung Han 
127*5f39d1b3SJooyung Han template <>
128*5f39d1b3SJooyung Han inline int16x16_m256i Neg(int16x16_m256i a) {
129*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_sign_epi16(a.v, _mm256_set1_epi16(-1)));
130*5f39d1b3SJooyung Han }
131*5f39d1b3SJooyung Han 
132*5f39d1b3SJooyung Han template <>
133*5f39d1b3SJooyung Han inline __m256i ShiftLeft(__m256i a, int offset) {
134*5f39d1b3SJooyung Han   return _mm256_slli_epi32(a, offset);
135*5f39d1b3SJooyung Han }
136*5f39d1b3SJooyung Han 
137*5f39d1b3SJooyung Han template <>
138*5f39d1b3SJooyung Han inline int16x16_m256i ShiftLeft(int16x16_m256i a, int offset) {
139*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_slli_epi16(a.v, offset));
140*5f39d1b3SJooyung Han }
141*5f39d1b3SJooyung Han 
142*5f39d1b3SJooyung Han template <>
143*5f39d1b3SJooyung Han inline __m256i ShiftRight(__m256i a, int offset) {
144*5f39d1b3SJooyung Han   return _mm256_srai_epi32(a, offset);
145*5f39d1b3SJooyung Han }
146*5f39d1b3SJooyung Han 
147*5f39d1b3SJooyung Han template <>
148*5f39d1b3SJooyung Han inline int16x16_m256i ShiftRight(int16x16_m256i a, int offset) {
149*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_srai_epi16(a.v, offset));
150*5f39d1b3SJooyung Han }
151*5f39d1b3SJooyung Han 
152*5f39d1b3SJooyung Han template <>
153*5f39d1b3SJooyung Han inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
154*5f39d1b3SJooyung Han                                __m256i else_val) {
155*5f39d1b3SJooyung Han   return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
156*5f39d1b3SJooyung Han                                               _mm256_castsi256_ps(then_val),
157*5f39d1b3SJooyung Han                                               _mm256_castsi256_ps(if_mask)));
158*5f39d1b3SJooyung Han }
159*5f39d1b3SJooyung Han 
160*5f39d1b3SJooyung Han template <>
161*5f39d1b3SJooyung Han inline int16x16_m256i SelectUsingMask(int16x16_m256i if_mask,
162*5f39d1b3SJooyung Han                                       int16x16_m256i then_val,
163*5f39d1b3SJooyung Han                                       int16x16_m256i else_val) {
164*5f39d1b3SJooyung Han   // Borrowed from Intel's arm_neon_sse.h header.
165*5f39d1b3SJooyung Han   return to_int16x16_m256i(
166*5f39d1b3SJooyung Han       _mm256_or_si256(_mm256_and_si256(if_mask.v, then_val.v),
167*5f39d1b3SJooyung Han                       _mm256_andnot_si256(if_mask.v, else_val.v)));
168*5f39d1b3SJooyung Han }
169*5f39d1b3SJooyung Han 
170*5f39d1b3SJooyung Han template <>
171*5f39d1b3SJooyung Han inline __m256i MaskIfEqual(__m256i a, __m256i b) {
172*5f39d1b3SJooyung Han   return _mm256_cmpeq_epi32(a, b);
173*5f39d1b3SJooyung Han }
174*5f39d1b3SJooyung Han 
175*5f39d1b3SJooyung Han template <>
176*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfEqual(int16x16_m256i a, int16x16_m256i b) {
177*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_cmpeq_epi16(a.v, b.v));
178*5f39d1b3SJooyung Han }
179*5f39d1b3SJooyung Han 
180*5f39d1b3SJooyung Han template <>
181*5f39d1b3SJooyung Han inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
182*5f39d1b3SJooyung Han   return BitNot(MaskIfEqual(a, b));
183*5f39d1b3SJooyung Han }
184*5f39d1b3SJooyung Han 
185*5f39d1b3SJooyung Han template <>
186*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfNotEqual(int16x16_m256i a, int16x16_m256i b) {
187*5f39d1b3SJooyung Han   return BitNot(MaskIfEqual(a, b));
188*5f39d1b3SJooyung Han }
189*5f39d1b3SJooyung Han 
190*5f39d1b3SJooyung Han template <>
191*5f39d1b3SJooyung Han inline __m256i MaskIfZero(__m256i a) {
192*5f39d1b3SJooyung Han   return MaskIfEqual(a, _mm256_set1_epi32(0));
193*5f39d1b3SJooyung Han }
194*5f39d1b3SJooyung Han 
195*5f39d1b3SJooyung Han template <>
196*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfZero(int16x16_m256i a) {
197*5f39d1b3SJooyung Han   return MaskIfEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
198*5f39d1b3SJooyung Han }
199*5f39d1b3SJooyung Han 
200*5f39d1b3SJooyung Han template <>
201*5f39d1b3SJooyung Han inline __m256i MaskIfNonZero(__m256i a) {
202*5f39d1b3SJooyung Han   return MaskIfNotEqual(a, _mm256_set1_epi32(0));
203*5f39d1b3SJooyung Han }
204*5f39d1b3SJooyung Han 
205*5f39d1b3SJooyung Han template <>
206*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfNonZero(int16x16_m256i a) {
207*5f39d1b3SJooyung Han   return MaskIfNotEqual(a, to_int16x16_m256i(_mm256_set1_epi16(0)));
208*5f39d1b3SJooyung Han }
209*5f39d1b3SJooyung Han 
210*5f39d1b3SJooyung Han template <>
211*5f39d1b3SJooyung Han inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
212*5f39d1b3SJooyung Han   return _mm256_cmpgt_epi32(a, b);
213*5f39d1b3SJooyung Han }
214*5f39d1b3SJooyung Han 
215*5f39d1b3SJooyung Han template <>
216*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfGreaterThan(int16x16_m256i a, int16x16_m256i b) {
217*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_cmpgt_epi16(a.v, b.v));
218*5f39d1b3SJooyung Han }
219*5f39d1b3SJooyung Han 
220*5f39d1b3SJooyung Han template <>
221*5f39d1b3SJooyung Han inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
222*5f39d1b3SJooyung Han   return _mm256_cmpgt_epi32(b, a);
223*5f39d1b3SJooyung Han }
224*5f39d1b3SJooyung Han 
225*5f39d1b3SJooyung Han template <>
226*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfLessThan(int16x16_m256i a, int16x16_m256i b) {
227*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_cmpgt_epi16(b.v, a.v));
228*5f39d1b3SJooyung Han }
229*5f39d1b3SJooyung Han 
230*5f39d1b3SJooyung Han template <>
231*5f39d1b3SJooyung Han inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
232*5f39d1b3SJooyung Han   return BitNot(MaskIfLessThan(a, b));
233*5f39d1b3SJooyung Han }
234*5f39d1b3SJooyung Han 
235*5f39d1b3SJooyung Han template <>
236*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfGreaterThanOrEqual(int16x16_m256i a,
237*5f39d1b3SJooyung Han                                                int16x16_m256i b) {
238*5f39d1b3SJooyung Han   return BitNot(MaskIfLessThan(a, b));
239*5f39d1b3SJooyung Han }
240*5f39d1b3SJooyung Han 
241*5f39d1b3SJooyung Han template <>
242*5f39d1b3SJooyung Han inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
243*5f39d1b3SJooyung Han   return BitNot(MaskIfGreaterThan(a, b));
244*5f39d1b3SJooyung Han }
245*5f39d1b3SJooyung Han 
246*5f39d1b3SJooyung Han template <>
247*5f39d1b3SJooyung Han inline int16x16_m256i MaskIfLessThanOrEqual(int16x16_m256i a,
248*5f39d1b3SJooyung Han                                             int16x16_m256i b) {
249*5f39d1b3SJooyung Han   return BitNot(MaskIfGreaterThan(a, b));
250*5f39d1b3SJooyung Han }
251*5f39d1b3SJooyung Han 
252*5f39d1b3SJooyung Han /* Assumptions:
253*5f39d1b3SJooyung Han    - All and Any are used on masks.
254*5f39d1b3SJooyung Han    - masks are all_ones for true lanes, all_zeroes otherwise.
255*5f39d1b3SJooyung Han Hence, All means all 128bits set, and Any means any bit set.
256*5f39d1b3SJooyung Han */
257*5f39d1b3SJooyung Han 
258*5f39d1b3SJooyung Han template <>
259*5f39d1b3SJooyung Han inline bool All(__m256i a) {
260*5f39d1b3SJooyung Han   return _mm256_testc_si256(a, a);
261*5f39d1b3SJooyung Han }
262*5f39d1b3SJooyung Han 
263*5f39d1b3SJooyung Han template <>
264*5f39d1b3SJooyung Han inline bool All(int16x16_m256i a) {
265*5f39d1b3SJooyung Han   return _mm256_testc_si256(a.v, a.v);
266*5f39d1b3SJooyung Han }
267*5f39d1b3SJooyung Han 
268*5f39d1b3SJooyung Han template <>
269*5f39d1b3SJooyung Han inline bool Any(__m256i a) {
270*5f39d1b3SJooyung Han   return BitNot(_mm256_testz_si256(a, a));
271*5f39d1b3SJooyung Han }
272*5f39d1b3SJooyung Han 
273*5f39d1b3SJooyung Han template <>
274*5f39d1b3SJooyung Han inline bool Any(int16x16_m256i a) {
275*5f39d1b3SJooyung Han   return BitNot(_mm256_testz_si256(a.v, a.v));
276*5f39d1b3SJooyung Han }
277*5f39d1b3SJooyung Han 
278*5f39d1b3SJooyung Han template <>
279*5f39d1b3SJooyung Han inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
280*5f39d1b3SJooyung Han   /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
281*5f39d1b3SJooyung Han   /* We divide the inputs before the add to avoid the overflow and costly test
282*5f39d1b3SJooyung Han    */
283*5f39d1b3SJooyung Han   /* of checking if an overflow occured on signed add */
284*5f39d1b3SJooyung Han   /* round_bit_mask = _mm_set1_epi32(1); */
285*5f39d1b3SJooyung Han   /* a_over_2 = _mm_srai_epi32(a, 1); */
286*5f39d1b3SJooyung Han   /* b_over_2 = _mm_srai_epi32(b, 1); */
287*5f39d1b3SJooyung Han   /* sum = Add(a_over_2, b_over_2); */
288*5f39d1b3SJooyung Han   /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
289*5f39d1b3SJooyung Han   /* return Add(sum, round_bit); */
290*5f39d1b3SJooyung Han 
291*5f39d1b3SJooyung Han   /* Other possibility detecting overflow and xor the sign if an overflow
292*5f39d1b3SJooyung Han    * happened*/
293*5f39d1b3SJooyung Han   __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
294*5f39d1b3SJooyung Han   one = _mm256_set1_epi32(1);
295*5f39d1b3SJooyung Han   sign_bit_mask = _mm256_set1_epi32(0x80000000);
296*5f39d1b3SJooyung Han   sum = Add(a, b);
297*5f39d1b3SJooyung Han   rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1);
298*5f39d1b3SJooyung Han   overflow =
299*5f39d1b3SJooyung Han       BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
300*5f39d1b3SJooyung Han              sign_bit_mask);
301*5f39d1b3SJooyung Han   result = BitXor(rounded_half_sum, overflow);
302*5f39d1b3SJooyung Han   return result;
303*5f39d1b3SJooyung Han }
304*5f39d1b3SJooyung Han 
305*5f39d1b3SJooyung Han template <>
306*5f39d1b3SJooyung Han inline int16x16_m256i RoundingHalfSum(int16x16_m256i a, int16x16_m256i b) {
307*5f39d1b3SJooyung Han   // Borrowed from Intel's arm_neon_sse.h header.
308*5f39d1b3SJooyung Han   __m256i constant_neg_32768 = _mm256_set1_epi16(-32768);
309*5f39d1b3SJooyung Han   __m256i a_unsigned = _mm256_sub_epi16(a.v, constant_neg_32768);
310*5f39d1b3SJooyung Han   __m256i b_unsigned = _mm256_sub_epi16(b.v, constant_neg_32768);
311*5f39d1b3SJooyung Han   __m256i avg_unsigned = _mm256_avg_epu16(a_unsigned, b_unsigned);
312*5f39d1b3SJooyung Han   __m256i avg = _mm256_add_epi16(avg_unsigned, constant_neg_32768);
313*5f39d1b3SJooyung Han   return to_int16x16_m256i(avg);
314*5f39d1b3SJooyung Han }
315*5f39d1b3SJooyung Han 
316*5f39d1b3SJooyung Han template <>
317*5f39d1b3SJooyung Han inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
318*5f39d1b3SJooyung Han   __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
319*5f39d1b3SJooyung Han   __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
320*5f39d1b3SJooyung Han   __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
321*5f39d1b3SJooyung Han   __m256i nudge;
322*5f39d1b3SJooyung Han 
323*5f39d1b3SJooyung Han   // saturation only happen if a == b == INT_MIN
324*5f39d1b3SJooyung Han   min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min());
325*5f39d1b3SJooyung Han   saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));
326*5f39d1b3SJooyung Han 
327*5f39d1b3SJooyung Han   // a = a0 | a1 | a2 | a3
328*5f39d1b3SJooyung Han   // b = b0 | b1 | b2 | b3
329*5f39d1b3SJooyung Han   a0_a2 = a;
330*5f39d1b3SJooyung Han   a1_a3 = _mm256_srli_si256(a, 4);
331*5f39d1b3SJooyung Han   b0_b2 = b;
332*5f39d1b3SJooyung Han   b1_b3 = _mm256_srli_si256(b, 4);
333*5f39d1b3SJooyung Han 
334*5f39d1b3SJooyung Han   a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2);
335*5f39d1b3SJooyung Han   a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3);
336*5f39d1b3SJooyung Han 
337*5f39d1b3SJooyung Han   // do the rounding and take into account that it will be doubled
338*5f39d1b3SJooyung Han   nudge = _mm256_set1_epi64x(1 << 30);
339*5f39d1b3SJooyung Han   a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge);
340*5f39d1b3SJooyung Han   a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge);
341*5f39d1b3SJooyung Han 
342*5f39d1b3SJooyung Han   // do the doubling
343*5f39d1b3SJooyung Han   a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1);
344*5f39d1b3SJooyung Han   a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1);
345*5f39d1b3SJooyung Han 
346*5f39d1b3SJooyung Han   // get the high part of the products
347*5f39d1b3SJooyung Han   result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4),
348*5f39d1b3SJooyung Han                               a1b1_a3b3_rounded_2x, 0xcc);
349*5f39d1b3SJooyung Han 
350*5f39d1b3SJooyung Han   // saturate those which overflowed
351*5f39d1b3SJooyung Han   return SelectUsingMask(saturation_mask, min, result);
352*5f39d1b3SJooyung Han }
353*5f39d1b3SJooyung Han 
354*5f39d1b3SJooyung Han template <>
355*5f39d1b3SJooyung Han inline int16x16_m256i SaturatingRoundingDoublingHighMul(int16x16_m256i a,
356*5f39d1b3SJooyung Han                                                         int16x16_m256i b) {
357*5f39d1b3SJooyung Han   // Use _mm256_mulhrs_epi16 then saturate with a bit-operation,
358*5f39d1b3SJooyung Han   // borrowed from Intel's arm_neon_sse.h header.
359*5f39d1b3SJooyung Han   __m256i result_unsaturated = _mm256_mulhrs_epi16(a.v, b.v);
360*5f39d1b3SJooyung Han   __m256i saturation_mask =
361*5f39d1b3SJooyung Han       _mm256_cmpeq_epi16(result_unsaturated, _mm256_set1_epi16(0x8000));
362*5f39d1b3SJooyung Han   __m256i result = _mm256_xor_si256(result_unsaturated, saturation_mask);
363*5f39d1b3SJooyung Han   return to_int16x16_m256i(result);
364*5f39d1b3SJooyung Han }
365*5f39d1b3SJooyung Han 
366*5f39d1b3SJooyung Han template <>
367*5f39d1b3SJooyung Han inline __m256i Dup<__m256i>(std::int32_t x) {
368*5f39d1b3SJooyung Han   return _mm256_set1_epi32(x);
369*5f39d1b3SJooyung Han }
370*5f39d1b3SJooyung Han 
371*5f39d1b3SJooyung Han template <>
372*5f39d1b3SJooyung Han inline int16x16_m256i Dup<int16x16_m256i>(std::int16_t x) {
373*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_set1_epi16(x));
374*5f39d1b3SJooyung Han }
375*5f39d1b3SJooyung Han 
376*5f39d1b3SJooyung Han // So far this is only needed for int16.
377*5f39d1b3SJooyung Han template <>
378*5f39d1b3SJooyung Han inline int16x16_m256i SaturatingAdd(int16x16_m256i a, int16x16_m256i b) {
379*5f39d1b3SJooyung Han   return to_int16x16_m256i(_mm256_adds_epi16(a.v, b.v));
380*5f39d1b3SJooyung Han }
381*5f39d1b3SJooyung Han 
382*5f39d1b3SJooyung Han }  // end namespace gemmlowp
383*5f39d1b3SJooyung Han 
384*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
385