// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8dwconv.h>
#include <requantization/runtime-sse2.h>

pytorch_q8dwconv_ukernel_up8x9__sse2(size_t channels,size_t output_width,const uint8_t ** input,const void * weights,uint8_t * output,size_t input_stride,size_t output_increment,const union pytorch_qnnp_conv_quantization_params quantization_params[RESTRICT_STATIC1])14 void pytorch_q8dwconv_ukernel_up8x9__sse2(
15     size_t channels,
16     size_t output_width,
17     const uint8_t** input,
18     const void* weights,
19     uint8_t* output,
20     size_t input_stride,
21     size_t output_increment,
22     const union pytorch_qnnp_conv_quantization_params
23         quantization_params[RESTRICT_STATIC 1]) {
24   const __m128i va_zero_point = _mm_load_si128(
25       (const __m128i*)quantization_params->sse2.input_zero_point);
26   const __m128i vkernel_zero_point = _mm_set1_epi16(
27       quantization_params->sse2.kernel_zero_points[0]);
28   const __m128i vzero = _mm_setzero_si128();
29 
30   do {
31     const uint8_t* i0 = input[0];
32     const uint8_t* i1 = input[1];
33     const uint8_t* i2 = input[2];
34     const uint8_t* i3 = input[3];
35     const uint8_t* i4 = input[4];
36     const uint8_t* i5 = input[5];
37     const uint8_t* i6 = input[6];
38     const uint8_t* i7 = input[7];
39     const uint8_t* i8 = input[8];
40 
41     input = (const uint8_t**)((uintptr_t)input + input_stride);
42 
43     size_t c = channels;
44     const void* w = weights;
45     for (; c >= 8; c -= 8) {
46       __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
47       __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));
48 
49       const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
50       i0 += 8;
51       const __m128i vxi0 =
52           sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
53       const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
54       const __m128i vxk0 =
55           _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
56       const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
57       const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
58       vacc_lo =
59           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
60       vacc_hi =
61           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));
62 
63       const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
64       i1 += 8;
65       const __m128i vxi1 =
66           sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
67       const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
68       const __m128i vxk1 =
69           _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
70       const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
71       const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
72       vacc_lo =
73           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
74       vacc_hi =
75           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));
76 
77       const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
78       i2 += 8;
79       const __m128i vxi2 =
80           sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
81       const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
82       const __m128i vxk2 =
83           _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
84       const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
85       const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
86       vacc_lo =
87           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
88       vacc_hi =
89           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));
90 
91       const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
92       i3 += 8;
93       const __m128i vxi3 =
94           sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
95       const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
96       const __m128i vxk3 =
97           _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
98       const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
99       const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
100       vacc_lo =
101           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
102       vacc_hi =
103           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));
104 
105       const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
106       i4 += 8;
107       const __m128i vxi4 =
108           sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
109       const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
110       const __m128i vxk4 =
111           _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
112       const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
113       const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
114       vacc_lo =
115           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
116       vacc_hi =
117           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));
118 
119       const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
120       i5 += 8;
121       const __m128i vxi5 =
122           sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
123       const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
124       const __m128i vxk5 =
125           _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
126       const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
127       const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
128       vacc_lo =
129           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
130       vacc_hi =
131           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));
132 
133       const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
134       i6 += 8;
135       const __m128i vxi6 =
136           sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
137       const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
138       const __m128i vxk6 =
139           _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
140       const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
141       const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
142       vacc_lo =
143           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
144       vacc_hi =
145           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));
146 
147       const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7);
148       i7 += 8;
149       const __m128i vxi7 =
150           sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
151       const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
152       const __m128i vxk7 =
153           _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
154       const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
155       const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
156       vacc_lo =
157           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
158       vacc_hi =
159           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));
160 
161       const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8);
162       i8 += 8;
163       const __m128i vxi8 =
164           sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
165       const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
166       const __m128i vxk8 =
167           _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
168       const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
169       const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
170       vacc_lo =
171           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
172       vacc_hi =
173           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));
174 
175       w = (void*)((uintptr_t)w + 104);
176 
177       const __m128 vmultiplier =
178           _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
179 
180       vacc_lo = _mm_cvtps_epi32(
181                     _mm_mul_ps(
182                       _mm_cvtepi32_ps(vacc_lo),
183                       vmultiplier
184                       )
185                     );
186       vacc_hi = _mm_cvtps_epi32(
187                     _mm_mul_ps(
188                       _mm_cvtepi32_ps(vacc_hi),
189                       vmultiplier
190                       )
191                     );
192 
193       const __m128i voutput_zero_point = _mm_load_si128(
194           (const __m128i*)quantization_params->sse2.output_zero_point);
195       __m128i vout =
196           _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), voutput_zero_point);
197       vout = _mm_packus_epi16(vout, vout);
198       vout = _mm_min_epu8(
199           vout,
200           _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
201       vout = _mm_max_epu8(
202           vout,
203           _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
204 
205       _mm_storel_epi64((__m128i*)output, vout);
206       output += 8;
207     }
208     if (c != 0) {
209       const size_t i_predecrement = 8 - c;
210       const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
211       i0 -= i_predecrement;
212       i1 -= i_predecrement;
213       i2 -= i_predecrement;
214       i3 -= i_predecrement;
215       i4 -= i_predecrement;
216       i5 -= i_predecrement;
217       i6 -= i_predecrement;
218       i7 -= i_predecrement;
219       i8 -= i_predecrement;
220 
221       __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
222       __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));
223 
224       const __m128i vi0 =
225           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
226       const __m128i vxi0 =
227           sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
228       const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
229       const __m128i vxk0 =
230           _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
231       const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
232       const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
233       vacc_lo =
234           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
235       vacc_hi =
236           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));
237 
238       const __m128i vi1 =
239           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
240       const __m128i vxi1 =
241           sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
242       const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
243       const __m128i vxk1 =
244           _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
245       const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
246       const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
247       vacc_lo =
248           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
249       vacc_hi =
250           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));
251 
252       const __m128i vi2 =
253           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
254       const __m128i vxi2 =
255           sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
256       const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
257       const __m128i vxk2 =
258           _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
259       const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
260       const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
261       vacc_lo =
262           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
263       vacc_hi =
264           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));
265 
266       const __m128i vi3 =
267           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
268       const __m128i vxi3 =
269           sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
270       const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
271       const __m128i vxk3 =
272           _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
273       const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
274       const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
275       vacc_lo =
276           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
277       vacc_hi =
278           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));
279 
280       const __m128i vi4 =
281           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
282       const __m128i vxi4 =
283           sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
284       const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
285       const __m128i vxk4 =
286           _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
287       const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
288       const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
289       vacc_lo =
290           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
291       vacc_hi =
292           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));
293 
294       const __m128i vi5 =
295           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
296       const __m128i vxi5 =
297           sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
298       const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
299       const __m128i vxk5 =
300           _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
301       const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
302       const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
303       vacc_lo =
304           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
305       vacc_hi =
306           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));
307 
308       const __m128i vi6 =
309           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
310       const __m128i vxi6 =
311           sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
312       const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
313       const __m128i vxk6 =
314           _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
315       const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
316       const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
317       vacc_lo =
318           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
319       vacc_hi =
320           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));
321 
322       const __m128i vi7 =
323           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vi_shift);
324       const __m128i vxi7 =
325           sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
326       const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
327       const __m128i vxk7 =
328           _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
329       const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
330       const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
331       vacc_lo =
332           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
333       vacc_hi =
334           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));
335 
336       const __m128i vi8 =
337           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vi_shift);
338       const __m128i vxi8 =
339           sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
340       const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
341       const __m128i vxk8 =
342           _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
343       const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
344       const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
345       vacc_lo =
346           _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
347       vacc_hi =
348           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));
349 
350       const __m128 vmultiplier =
351           _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
352 
353       vacc_lo = _mm_cvtps_epi32(
354                     _mm_mul_ps(
355                       _mm_cvtepi32_ps(vacc_lo),
356                       vmultiplier
357                       )
358                     );
359       vacc_hi = _mm_cvtps_epi32(
360                     _mm_mul_ps(
361                       _mm_cvtepi32_ps(vacc_hi),
362                       vmultiplier
363                       )
364                     );
365 
366       const __m128i voutput_zero_point = _mm_load_si128(
367           (const __m128i*)quantization_params->sse2.output_zero_point);
368       __m128i vout =
369           _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), voutput_zero_point);
370       vout = _mm_packus_epi16(vout, vout);
371       vout = _mm_min_epu8(
372           vout,
373           _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
374       vout = _mm_max_epu8(
375           vout,
376           _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
377 
378       if (c & 4) {
379         *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
380         output += 4;
381         vout = _mm_srli_epi64(vout, 32);
382       }
383       if (c & 2) {
384         *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
385         output += 2;
386         vout = _mm_srli_epi32(vout, 16);
387       }
388       if (c & 1) {
389         *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
390         output += 1;
391       }
392     }
393 
394     output = (uint8_t*)((uintptr_t)output + output_increment);
395   } while (--output_width != 0);
396 }
397