// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x4c2-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8conv.h>
#include <requantization/runtime-sse2.h>

pytorch_q8conv_ukernel_4x4c2__sse2(size_t mr,size_t nr,size_t kc,size_t ks,const uint8_t ** restrict a,const void * restrict w,uint8_t * restrict c,size_t c_stride,size_t output_channel_index,const union pytorch_qnnp_conv_quantization_params quantization_params[RESTRICT_STATIC1])14 void pytorch_q8conv_ukernel_4x4c2__sse2(
15     size_t mr,
16     size_t nr,
17     size_t kc,
18     size_t ks,
19     const uint8_t** restrict a,
20     const void* restrict w,
21     uint8_t* restrict c,
22     size_t c_stride,
23     size_t output_channel_index,
24     const union pytorch_qnnp_conv_quantization_params
25         quantization_params[RESTRICT_STATIC 1]) {
26   __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w);
27   __m128i vacc1x0123 = vacc0x0123;
28   __m128i vacc2x0123 = vacc0x0123;
29   __m128i vacc3x0123 = vacc0x0123;
30   w = (const void*)((uintptr_t)w + 16);
31 
32   const __m128i va_zero_point = _mm_load_si128(
33       (const __m128i*)quantization_params->sse2.input_zero_point);
34   const int16_t vb_zero_point_0 =
35     quantization_params->sse2.kernel_zero_points[output_channel_index];
36   const int16_t vb_zero_point_1 =
37       quantization_params->sse2.kernel_zero_points[output_channel_index + 1];
38   const int16_t vb_zero_point_2 =
39       quantization_params->sse2.kernel_zero_points[output_channel_index + 2];
40   const int16_t vb_zero_point_3 =
41       quantization_params->sse2.kernel_zero_points[output_channel_index + 3];
42 
43   const __m128i vb_zero_point = _mm_set_epi16(vb_zero_point_3,
44                                               vb_zero_point_3,
45                                               vb_zero_point_2,
46                                               vb_zero_point_2,
47                                               vb_zero_point_1,
48                                               vb_zero_point_1,
49                                               vb_zero_point_0,
50                                               vb_zero_point_0
51                                               );
52   const __m128i vzero = _mm_setzero_si128();
53   do {
54     const uint8_t* restrict a0 = *a++;
55     const uint8_t* restrict a1 = *a++;
56     const uint8_t* restrict a2 = *a++;
57     const uint8_t* restrict a3 = *a++;
58 
59     size_t k = kc;
60     for (; k >= 8; k -= 8) {
61       const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
62       const __m128i vxa0 =
63           sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
64       a0 += 8;
65       const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
66       const __m128i vxa1 =
67           sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
68       a1 += 8;
69       const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2);
70       const __m128i vxa2 =
71           sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
72       a2 += 8;
73       const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3);
74       const __m128i vxa3 =
75           sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
76       a3 += 8;
77 
78       const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
79       const __m128i vxb0 =
80           _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
81       vacc0x0123 = _mm_add_epi32(
82           vacc0x0123,
83           _mm_madd_epi16(
84               _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
85       vacc1x0123 = _mm_add_epi32(
86           vacc1x0123,
87           _mm_madd_epi16(
88               _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
89       vacc2x0123 = _mm_add_epi32(
90           vacc2x0123,
91           _mm_madd_epi16(
92               _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
93       vacc3x0123 = _mm_add_epi32(
94           vacc3x0123,
95           _mm_madd_epi16(
96               _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
97 
98       const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
99       const __m128i vxb1 =
100           _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
101       vacc0x0123 = _mm_add_epi32(
102           vacc0x0123,
103           _mm_madd_epi16(
104               _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
105       vacc1x0123 = _mm_add_epi32(
106           vacc1x0123,
107           _mm_madd_epi16(
108               _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
109       vacc2x0123 = _mm_add_epi32(
110           vacc2x0123,
111           _mm_madd_epi16(
112               _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
113       vacc3x0123 = _mm_add_epi32(
114           vacc3x0123,
115           _mm_madd_epi16(
116               _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
117 
118       const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
119       const __m128i vxb2 =
120           _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
121       vacc0x0123 = _mm_add_epi32(
122           vacc0x0123,
123           _mm_madd_epi16(
124               _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
125       vacc1x0123 = _mm_add_epi32(
126           vacc1x0123,
127           _mm_madd_epi16(
128               _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
129       vacc2x0123 = _mm_add_epi32(
130           vacc2x0123,
131           _mm_madd_epi16(
132               _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
133       vacc3x0123 = _mm_add_epi32(
134           vacc3x0123,
135           _mm_madd_epi16(
136               _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
137 
138       const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
139       const __m128i vxb3 =
140           _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
141       vacc0x0123 = _mm_add_epi32(
142           vacc0x0123,
143           _mm_madd_epi16(
144               _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
145       vacc1x0123 = _mm_add_epi32(
146           vacc1x0123,
147           _mm_madd_epi16(
148               _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
149       vacc2x0123 = _mm_add_epi32(
150           vacc2x0123,
151           _mm_madd_epi16(
152               _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
153       vacc3x0123 = _mm_add_epi32(
154           vacc3x0123,
155           _mm_madd_epi16(
156               _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
157 
158       w = (void*)((uintptr_t)w + 32);
159     }
160     if (k != 0) {
161       const size_t a_predecrement = 8 - k;
162       const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);
163 
164       const __m128i va0 = _mm_srl_epi64(
165           _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
166       const __m128i vxa0 =
167           sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
168       const __m128i va1 = _mm_srl_epi64(
169           _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
170       const __m128i vxa1 =
171           sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
172       const __m128i va2 = _mm_srl_epi64(
173           _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift);
174       const __m128i vxa2 =
175           sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
176       const __m128i va3 = _mm_srl_epi64(
177           _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift);
178       const __m128i vxa3 =
179           sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
180 
181       const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
182       const __m128i vxb0 =
183           _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
184       w = (void*)((uintptr_t)w + 8);
185 
186       vacc0x0123 = _mm_add_epi32(
187           vacc0x0123,
188           _mm_madd_epi16(
189               _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
190       vacc1x0123 = _mm_add_epi32(
191           vacc1x0123,
192           _mm_madd_epi16(
193               _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
194       vacc2x0123 = _mm_add_epi32(
195           vacc2x0123,
196           _mm_madd_epi16(
197               _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
198       vacc3x0123 = _mm_add_epi32(
199           vacc3x0123,
200           _mm_madd_epi16(
201               _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
202 
203       if (k > 2) {
204         const __m128i vb1 = _mm_loadl_epi64((const __m128i*)w);
205         const __m128i vxb1 =
206             _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
207         w = (void*)((uintptr_t)w + 8);
208 
209         vacc0x0123 = _mm_add_epi32(
210             vacc0x0123,
211             _mm_madd_epi16(
212                 _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
213         vacc1x0123 = _mm_add_epi32(
214             vacc1x0123,
215             _mm_madd_epi16(
216                 _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
217         vacc2x0123 = _mm_add_epi32(
218             vacc2x0123,
219             _mm_madd_epi16(
220                 _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
221         vacc3x0123 = _mm_add_epi32(
222             vacc3x0123,
223             _mm_madd_epi16(
224                 _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
225 
226         if (k > 4) {
227           const __m128i vb2 = _mm_loadl_epi64((const __m128i*)w);
228           const __m128i vxb2 =
229               _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
230           w = (void*)((uintptr_t)w + 8);
231 
232           vacc0x0123 = _mm_add_epi32(
233               vacc0x0123,
234               _mm_madd_epi16(
235                   _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
236           vacc1x0123 = _mm_add_epi32(
237               vacc1x0123,
238               _mm_madd_epi16(
239                   _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
240           vacc2x0123 = _mm_add_epi32(
241               vacc2x0123,
242               _mm_madd_epi16(
243                   _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
244           vacc3x0123 = _mm_add_epi32(
245               vacc3x0123,
246               _mm_madd_epi16(
247                   _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
248 
249           if (k > 6) {
250             const __m128i vb3 = _mm_loadl_epi64((const __m128i*)w);
251             const __m128i vxb3 =
252                 _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
253             w = (void*)((uintptr_t)w + 8);
254 
255             vacc0x0123 = _mm_add_epi32(
256                 vacc0x0123,
257                 _mm_madd_epi16(
258                     _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
259             vacc1x0123 = _mm_add_epi32(
260                 vacc1x0123,
261                 _mm_madd_epi16(
262                     _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
263             vacc2x0123 = _mm_add_epi32(
264                 vacc2x0123,
265                 _mm_madd_epi16(
266                     _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
267             vacc3x0123 = _mm_add_epi32(
268                 vacc3x0123,
269                 _mm_madd_epi16(
270                     _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
271           }
272         }
273       }
274     }
275   } while (--ks != 0);
276 
277   const __m128 vmultiplier =
278       _mm_loadu_ps(&quantization_params->sse2.requantization_scales
279           [output_channel_index]);
280 
281   vacc0x0123 = _mm_cvtps_epi32(
282                 _mm_mul_ps(
283                   _mm_cvtepi32_ps(vacc0x0123),
284                   vmultiplier
285                   )
286                 );
287   vacc1x0123 = _mm_cvtps_epi32(
288                 _mm_mul_ps(
289                   _mm_cvtepi32_ps(vacc1x0123),
290                   vmultiplier
291                   )
292                 );
293   vacc2x0123 = _mm_cvtps_epi32(
294                 _mm_mul_ps(
295                   _mm_cvtepi32_ps(vacc2x0123),
296                   vmultiplier
297                   )
298                 );
299   vacc3x0123 = _mm_cvtps_epi32(
300                 _mm_mul_ps(
301                   _mm_cvtepi32_ps(vacc3x0123),
302                   vmultiplier
303                   )
304                 );
305 
306   const __m128i voutput_zero_point = _mm_load_si128(
307       (const __m128i*)quantization_params->sse2.output_zero_point);
308   const __m128i vacc01x0123 = _mm_adds_epi16(
309       _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
310   const __m128i vacc23x0123 = _mm_adds_epi16(
311       _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
312   __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
313   vout = _mm_min_epu8(
314       vout,
315       _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
316   vout = _mm_max_epu8(
317       vout,
318       _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
319 
320   uint8_t* c0 = c;
321   uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
322   if (mr < 2) {
323     c1 = c0;
324   }
325   uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
326   if (mr <= 2) {
327     c2 = c1;
328   }
329   uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
330   if (mr != 4) {
331     c3 = c2;
332   }
333   if (nr == 4) {
334     *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout);
335     *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
336     *((uint32_t*)c2) =
337         (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
338     *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
339   } else {
340     typedef PYTORCH_QNNP_UNALIGNED uint16_t unaligned_uint16_t;
341     if (nr >= 2) {
342       *((unaligned_uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0);
343       c0 += 2;
344       *((unaligned_uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2);
345       c1 += 2;
346       *((unaligned_uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4);
347       c2 += 2;
348       *((unaligned_uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6);
349       c3 += 2;
350       vout = _mm_srli_epi32(vout, 16);
351       nr -= 2;
352     }
353     if (nr != 0) {
354       *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout);
355       *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2);
356       *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4);
357       *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6);
358     }
359   }
360 }
361