xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <assert.h>
#include <string.h>

#include <emmintrin.h>

#include <qnnpack/q8gavgpool.h>
14 
pytorch_q8gavgpool_ukernel_up8x7__sse2(size_t m,size_t n,const uint8_t * input,size_t input_stride,const uint8_t * zero,uint8_t * output,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8gavgpool_ukernel_up8x7__sse2(
16     size_t m,
17     size_t n,
18     const uint8_t* input,
19     size_t input_stride,
20     const uint8_t* zero,
21     uint8_t* output,
22     const union pytorch_qnnp_avgpool_quantization_params
23         quantization_params[RESTRICT_STATIC 1]) {
24   assert(m >= 1);
25   assert(m <= 7);
26   assert(n >= 8);
27 
28   const uint8_t* i0 = input;
29   const uint8_t* i1 = i0 + input_stride;
30   if (m < 2) {
31     i1 = zero;
32   }
33   const uint8_t* i2 = i1 + input_stride;
34   if (m <= 2) {
35     i2 = zero;
36   }
37   const uint8_t* i3 = i2 + input_stride;
38   if (m < 4) {
39     i3 = zero;
40   }
41   const uint8_t* i4 = i3 + input_stride;
42   if (m <= 4) {
43     i4 = zero;
44   }
45   const uint8_t* i5 = i4 + input_stride;
46   if (m < 6) {
47     i5 = zero;
48   }
49   const uint8_t* i6 = i5 + input_stride;
50   if (m <= 6) {
51     i6 = zero;
52   }
53   const __m128i vbias =
54       _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
55   const __m128i vzero = _mm_setzero_si128();
56 
57   const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
58 
59   do {
60     const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
61     i0 += 8;
62     const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
63     i1 += 8;
64     const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
65     i2 += 8;
66     const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
67     i3 += 8;
68     const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
69     i4 += 8;
70     const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
71     i5 += 8;
72     const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
73     i6 += 8;
74 
75     const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
76     const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
77     const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
78     const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
79     const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
80     const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
81     const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
82 
83     __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
84     __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
85     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
86     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
87     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
88     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
89     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
90     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
91     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
92     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
93     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
94     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
95     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
96     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));
97 
98     const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
99     const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
100 
101     const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
102     const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
103 
104     __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
105     vout = _mm_adds_epi16(
106         vout,
107         _mm_load_si128(
108             (const __m128i*)quantization_params->sse2.output_zero_point));
109     vout = _mm_packus_epi16(vout, vout);
110     vout = _mm_min_epu8(
111         vout,
112         _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
113     vout = _mm_max_epu8(
114         vout,
115         _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
116 
117     _mm_storel_epi64((__m128i*)output, vout);
118     output += 8;
119 
120     n -= 8;
121   } while (n >= 8);
122   if (n != 0) {
123     const size_t address_decrement = 8 - n;
124     i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
125     i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
126     i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
127     i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
128     i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
129     i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
130     i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
131     const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement);
132 
133     const __m128i vi0 =
134         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
135     const __m128i vi1 =
136         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
137     const __m128i vi2 =
138         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
139     const __m128i vi3 =
140         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
141     const __m128i vi4 =
142         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
143     const __m128i vi5 =
144         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
145     const __m128i vi6 =
146         _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
147 
148     const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
149     const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
150     const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
151     const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
152     const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
153     const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
154     const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
155 
156     __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
157     __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
158     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
159     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
160     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
161     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
162     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
163     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
164     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
165     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
166     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
167     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
168     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
169     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));
170 
171     const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
172     const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
173 
174     const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
175     const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
176 
177     __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
178     vout = _mm_adds_epi16(
179         vout,
180         _mm_load_si128(
181             (const __m128i*)quantization_params->sse2.output_zero_point));
182     vout = _mm_packus_epi16(vout, vout);
183     vout = _mm_min_epu8(
184         vout,
185         _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
186     vout = _mm_max_epu8(
187         vout,
188         _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
189 
190     if (n & 4) {
191       *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
192       output += 4;
193       vout = _mm_srli_epi64(vout, 32);
194     }
195     if (n & 2) {
196       *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
197       output += 2;
198       vout = _mm_srli_epi32(vout, 16);
199     }
200     if (n & 1) {
201       *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
202     }
203   }
204 }
205