// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

using Int32x4 = int32x4_t;
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
using Int8x8 = int8x8_t;

template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int16x8,
      typename std::conditional<ScalarCount >= 4, Int16x4,
                                std::int16_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Uint8x8,
      typename std::conditional<ScalarCount >= 4, std::uint32_t,
                                std::uint8_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::int8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int8x8,
      typename std::conditional<ScalarCount >= 4, std::int32_t,
                                std::int8_t>::type>::type;
};
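
// Editor's note (illustrative, not part of the upstream header): each
// RegisterType specialization above picks the widest NEON register that
// still fits ScalarCount lanes, falling back to packed or scalar types
// below that. Assuming <type_traits> is visible (simd_wrappers.h, which
// includes this file, uses it too), the selection can be checked with:
//
//   static_assert(std::is_same<RegisterType<std::int32_t, 4>::Type,
//                              Int32x4>::value, "");
//   static_assert(std::is_same<RegisterType<std::int32_t, 2>::Type,
//                              std::int32_t>::value, "");
//   // Four uint8 lanes ride in a plain uint32_t, not a NEON register:
//   static_assert(std::is_same<RegisterType<std::uint8_t, 4>::Type,
//                              std::uint32_t>::value, "");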

inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }

inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
  vst1q_s32(dst, value);
}

inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
  vst1_s16(dst, value);
}

inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
  vst1q_s16(dst, value);
}
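
// Illustrative round-trip (editor's sketch, not upstream code): the
// Load*/Store* helpers are thin wrappers over vld1/vst1, so the
// pointed-to buffer must hold the full register width (4 or 8 lanes).
//
//   std::int32_t buf[4] = {1, 2, 3, 4};
//   Int32x4 v = LoadInt32x4(buf);  // vld1q_s32: reads all 4 lanes
//   StoreInt32x4(buf, v);          // vst1q_s32: writes all 4 lanes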

template <int Lane>
std::int32_t GetLane(Int32x4 value) {
  return vgetq_lane_s32(value, Lane);
}

template <int Lane>
Int32x4 DupLane(Int32x4 value) {
  switch (Lane) {
    case 0:
      return vdupq_lane_s32(vget_low_s32(value), 0);
    case 1:
      return vdupq_lane_s32(vget_low_s32(value), 1);
    case 2:
      return vdupq_lane_s32(vget_high_s32(value), 0);
    case 3:
      return vdupq_lane_s32(vget_high_s32(value), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}
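
// Editor's sketch of the lane helpers (not upstream code). Because Lane
// is a compile-time template parameter, the switch in DupLane folds away
// to a single lane-dup instruction:
//
//   Int32x4 v = LoadInt32x4(buf);    // say {10, 20, 30, 40}
//   std::int32_t x = GetLane<2>(v);  // 30
//   Int32x4 d = DupLane<1>(v);       // {20, 20, 20, 20}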

inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); }

inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, std::int32_t b) {
  return vmaxq_s32(a, vdupq_n_s32(b));
}

inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
  return vqrdmulhq_n_s32(a, b);
}
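
// Editor's note: per lane, vqrdmulhq_n_s32 returns the high 32 bits of
// the doubled 64-bit product, with rounding, saturating in the single
// overflow case a == b == INT32_MIN:
//
//   result = saturate_int32((2 * (std::int64_t)a * b + (1ll << 31)) >> 32)
//
// This is the Q31 fixed-point multiply underlying gemmlowp's fixed-point
// requantization math (see fixedpoint.h).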

template <int Lane>
Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
  switch (Lane) {
    case 0:
      return vmulq_lane_s32(a, vget_low_s32(b), 0);
    case 1:
      return vmulq_lane_s32(a, vget_low_s32(b), 1);
    case 2:
      return vmulq_lane_s32(a, vget_high_s32(b), 0);
    case 3:
      return vmulq_lane_s32(a, vget_high_s32(b), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  *acc = vmlaq_s32(*acc, lhs, rhs);
}

inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
  *acc = vmlaq_n_s32(*acc, lhs, rhs);
}

template <int Lane>
inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  switch (Lane) {
    case 0:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0);
      break;
    case 1:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1);
      break;
    case 2:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0);
      break;
    case 3:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1);
      break;
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
  }
}
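
// Minimal accumulation sketch (editor's example, not upstream code; the
// names lhs_ptr, rhs_ptr, depth are placeholders): the MulAdd overloads
// wrap NEON multiply-accumulate (vmlaq_s32 and friends), so a lane-wise
// accumulator loop looks like:
//
//   Int32x4 acc = vdupq_n_s32(0);
//   for (int i = 0; i < depth; i += 4) {  // assumes depth % 4 == 0
//     MulAdd(LoadInt32x4(lhs_ptr + i), LoadInt32x4(rhs_ptr + i), &acc);
//   }                                     // acc += lhs * rhs per lane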

template <>
struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
  static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
    RegBlockInt16<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1q_s16(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
    RegBlockUint8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_u8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt8<8, 8>> {
  static RegBlockInt8<8, 8> Run(const std::int8_t* src) {
    RegBlockInt8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_s8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
    RegBlockInt32<8, 8> result;
    for (int i = 0; i < 16; i++) {
      result.buf.reg[i] = vld1q_s32(src + 4 * i);
    }
    return result;
  }
};
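
// Editor's note: each LoadContiguousImpl specialization above fills an
// 8x8 register block from 64 contiguous scalars. The 8-bit and 16-bit
// blocks take 8 registers of 8 lanes each, while the int32 block takes
// 16 q-registers of 4 lanes, hence the different loop bounds.
// Hypothetical usage:
//
//   std::int16_t buf[64];  // illustrative source buffer
//   auto block = LoadContiguousImpl<RegBlockInt16<8, 8>>::Run(buf);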

// 4x1 := 4x1 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 1x4 := 1x4 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x1 := 4x1 << 4x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 1x4 := 1x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 4x4 := 4x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x4 := 4x4 << 4x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]);
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]);
    return result;
  }
};

// 8x1 := 8x1 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p);
    }
    return result;
  }
};

// 8x1 := 8x1 << 8x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]);
    }
    return result;
  }
};

// 8x4 := 8x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 8x4 := 8x4 << 8x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 << 1x8
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 8>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};
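
// Editor's note on the broadcasting scheme above (the rounding-divide
// specializations below follow the same pattern): when rhs has fewer
// columns than lhs, DupLane<j> splats rhs lane j across the registers of
// lhs column j; when rhs has fewer rows, its registers are simply reused
// across row groups. E.g. for 4x4 << 1x4:
//
//   // out column j == lhs column j << rhs lane j, for j in 0..3
//   RegBlockInt32<4, 4> out = BroadcastShiftLeftImpl<
//       RegBlockInt32<4, 4>, RegBlockInt32<1, 4>>::Run(lhs, rhs);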

// 4x1 := 4x1 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 1x4 := 1x4 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x1 := 4x1 >> 4x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
                                        RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 1x4 := 1x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 4x4 := 4x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[2] =
        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[3] =
        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x4 := 4x4 >> 4x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
                                        RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
    return result;
  }
};

// 8x1 := 8x1 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
    }
    return result;
  }
};

// 8x1 := 8x1 >> 8x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
                                        RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
    }
    return result;
  }
};

// 8x4 := 8x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[2] =
        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[3] =
        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[4] =
        RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[5] =
        RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[6] =
        RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
    result.buf.reg[7] =
        RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 8x4 := 8x4 >> 8x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
                                        RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
    result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
    result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
    result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
    result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 >> 1x8
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
                                        RegBlockInt32<1, 8>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 8>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};
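
// Editor's note: the per-lane arithmetic above is delegated to
// RoundingDivideByPOT (fixedpoint.h), a rounding arithmetic right shift:
// round(x / 2^exponent), with ties rounded away from zero. The
// specializations here only handle the shape broadcasting. E.g.:
//
//   RoundingDivideByPOT(5, 1)   // == 3   (5 / 2 = 2.5 rounds to 3)
//   RoundingDivideByPOT(-5, 1)  // == -3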

}  // end namespace gemmlowp

#include "simd_wrappers_common_neon_sse.h"

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_