/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <emmintrin.h>

#include <qnnpack/q8gavgpool.h>

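/*
 * Global average pooling micro-kernel, multi-pass ("mp") variant: reduces
 * m rows (m > 7) of n uint8 channels each (n >= 8) to a single row of n
 * uint8 averages, consuming rows 7 at a time in an 8-channel-wide tile.
 * Partial sums are kept as int32 in the caller-provided `buffer`, which
 * must be 16-byte aligned and cover n rounded up to a multiple of 8; the
 * final pass requantizes the sums and writes them to `output`.
 */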
void pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  assert(m > 7);
  assert(n >= 8);

  const uint8_t* i0 = input;
  const uint8_t* i1 = i0 + input_stride;
  const uint8_t* i2 = i1 + input_stride;
  const uint8_t* i3 = i2 + input_stride;
  const uint8_t* i4 = i3 + input_stride;
  const uint8_t* i5 = i4 + input_stride;
  const uint8_t* i6 = i5 + input_stride;
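  /*
   * Each pass advances the row pointers by packed_n bytes (n rounded up to
   * a multiple of 8), so adding input_increment afterwards moves them
   * forward by exactly 7 rows.
   */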
  const size_t packed_n = (n + 7) & -8;
  const size_t input_increment = 7 * input_stride - packed_n;
  const __m128i vbias =
      _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
  const __m128i vzero = _mm_setzero_si128();

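  /*
   * First pass: seed the accumulator buffer with the bias plus the sum of
   * the first 7 rows, 8 channels per iteration (each 8-byte group is
   * zero-extended u8 -> u16 -> u32 before the adds).
   */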
  /* note: reads up to 7 elements past the end of each row */
  int32_t* acc = buffer;
  for (size_t k = 0; k < n; k += 8) {
    const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
    i0 += 8;
    const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
    i1 += 8;
    const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
    i2 += 8;
    const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
    i3 += 8;
    const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
    i4 += 8;
    const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
    i5 += 8;
    const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
    i6 += 8;

    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
    __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

    _mm_store_si128((__m128i*)acc, vacc_lo);
    _mm_store_si128((__m128i*)acc + 1, vacc_hi);
    acc += 8;
  }
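  /*
   * Middle passes: while more than 7 rows remain, add the next 7 rows into
   * the accumulators already stored in the buffer.
   */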
  for (m -= 7; m > 7; m -= 7) {
    acc = buffer;
    i0 = (const uint8_t*)((uintptr_t)i0 + input_increment);
    i1 = (const uint8_t*)((uintptr_t)i1 + input_increment);
    i2 = (const uint8_t*)((uintptr_t)i2 + input_increment);
    i3 = (const uint8_t*)((uintptr_t)i3 + input_increment);
    i4 = (const uint8_t*)((uintptr_t)i4 + input_increment);
    i5 = (const uint8_t*)((uintptr_t)i5 + input_increment);
    i6 = (const uint8_t*)((uintptr_t)i6 + input_increment);

    /* note: reads up to 7 elements past the end of each row */
    for (size_t k = 0; k < n; k += 8) {
      const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
      i0 += 8;
      const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
      i1 += 8;
      const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
      i2 += 8;
      const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
      i3 += 8;
      const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
      i4 += 8;
      const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
      i5 += 8;
      const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
      i6 += 8;
      __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
      __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);

      const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
      const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
      const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
      const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
      const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
      const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
      const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

      _mm_store_si128((__m128i*)acc, vacc_lo);
      _mm_store_si128((__m128i*)acc + 1, vacc_hi);
      acc += 8;
    }
  }

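  /*
   * Final pass: between 1 and 7 rows remain. Row pointers that would run
   * past the last row are redirected to the caller-provided `zero` vector
   * so they contribute nothing to the sums.
   */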
  const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);

  i0 = (const uint8_t*)((uintptr_t)i0 + input_increment);
  i1 = (const uint8_t*)((uintptr_t)i1 + input_increment);
  if (m < 2) {
    i1 = zero;
  }
  i2 = (const uint8_t*)((uintptr_t)i2 + input_increment);
  if (m <= 2) {
    i2 = zero;
  }
  i3 = (const uint8_t*)((uintptr_t)i3 + input_increment);
  if (m < 4) {
    i3 = zero;
  }
  i4 = (const uint8_t*)((uintptr_t)i4 + input_increment);
  if (m <= 4) {
    i4 = zero;
  }
  i5 = (const uint8_t*)((uintptr_t)i5 + input_increment);
  if (m < 6) {
    i5 = zero;
  }
  i6 = (const uint8_t*)((uintptr_t)i6 + input_increment);
  if (m <= 6) {
    i6 = zero;
  }

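  /*
   * Produce output 8 channels at a time: add the last rows into the
   * buffered sums, then requantize each group and store it.
   */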
  acc = buffer;
  do {
    const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
    i0 += 8;
    const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
    i1 += 8;
    const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
    i2 += 8;
    const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
    i3 += 8;
    const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
    i4 += 8;
    const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
    i5 += 8;
    const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
    i6 += 8;
    __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
    __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);
    acc += 8;

    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

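    /*
     * Requantize: int32 sums -> float, multiply by the averaging scale,
     * round back to int32 (nearest, under the default MXCSR mode), add the
     * output zero point, saturate to uint8, and clamp to
     * [output_min, output_max].
     */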
    const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
    const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);

    const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
    const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);

    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
    vout = _mm_adds_epi16(
        vout,
        _mm_load_si128(
            (const __m128i*)quantization_params->sse2.output_zero_point));
    vout = _mm_packus_epi16(vout, vout);
    vout = _mm_min_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
    vout = _mm_max_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

    _mm_storel_epi64((__m128i*)output, vout);
    output += 8;

    n -= 8;
  } while (n >= 8);
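  /*
   * Tail: 1-7 channels remain. Rewind each pointer by (8 - n) bytes so the
   * final 8-byte load ends exactly at the end of the row, then shift the
   * already-processed low bytes out of each 64-bit lane before summing.
   */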
  if (n != 0) {
    const size_t address_decrement = 8 - n;
    i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
    i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
    i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
    i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
    i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
    i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
    i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
    const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement);

    const __m128i vi0 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
    const __m128i vi1 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
    const __m128i vi2 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
    const __m128i vi3 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
    const __m128i vi4 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
    const __m128i vi5 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
    const __m128i vi6 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
    __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
    __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);

    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

    const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
    const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);

    const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
    const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);

    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
    vout = _mm_adds_epi16(
        vout,
        _mm_load_si128(
            (const __m128i*)quantization_params->sse2.output_zero_point));
    vout = _mm_packus_epi16(vout, vout);
    vout = _mm_min_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
    vout = _mm_max_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

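    /* Store the remaining 1-7 bytes with progressively narrower stores. */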
    if (n & 4) {
      *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
      output += 4;
      vout = _mm_srli_epi64(vout, 32);
    }
    if (n & 2) {
      *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
      output += 2;
      vout = _mm_srli_epi32(vout, 16);
    }
    if (n & 1) {
      *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
    }
  }
}
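
/*
 * Usage sketch (illustrative only, not part of the original file; `qparams`,
 * `scratch`, and `zeros` are hypothetical names, and qparams is assumed to
 * have been filled in by the caller). The scratch buffer must be 16-byte
 * aligned and sized to n rounded up to a multiple of 8, and the `zero`
 * argument must point at at least that many zero bytes:
 *
 *   const size_t packed_n = (n + 7) & -8;
 *   int32_t* scratch = aligned_alloc(16, packed_n * sizeof(int32_t));
 *   uint8_t* zeros = calloc(packed_n, sizeof(uint8_t));
 *   pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2(
 *       m, n, input, input_stride, zeros, scratch, output, &qparams);
 */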