xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <assert.h>
#include <string.h>

#include <emmintrin.h>

#include <qnnpack/q8avgpool.h>
14 
pytorch_q8avgpool_ukernel_up8xm__sse2(size_t n,size_t ks,size_t kc,const uint8_t ** input,const uint8_t * zero,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8avgpool_ukernel_up8xm__sse2(
16     size_t n,
17     size_t ks,
18     size_t kc,
19     const uint8_t** input,
20     const uint8_t* zero,
21     uint8_t* output,
22     size_t input_increment,
23     size_t output_increment,
24     const union pytorch_qnnp_avgpool_quantization_params
25         quantization_params[RESTRICT_STATIC 1]) {
26   assert(n != 0);
27   assert(ks != 0);
28   assert(kc < 8);
29 
30   const __m128i vbias =
31       _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
32   const __m128i vzero = _mm_setzero_si128();
33   const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
34 
35   do {
36     const uint8_t** next_input =
37         (const uint8_t**)((uintptr_t)input + input_increment);
38     __m128i vacc_lo = vbias;
39     __m128i vacc_hi = vbias;
40 
41     size_t m = ks;
42     do {
43       const uint8_t* i = *input++;
44       i += kc;
45       __m128i vi = _mm_setzero_si128();
46       if (kc & 1) {
47         i -= 1;
48         vi = _mm_cvtsi32_si128((int)(uint32_t)*i);
49       }
50       if (kc & 2) {
51         vi = _mm_slli_epi32(vi, 16);
52         i -= 2;
53         vi = _mm_insert_epi16(vi, *((const uint16_t*)i), 0);
54       }
55       if (kc & 4) {
56         i -= 4;
57         vi = _mm_unpacklo_epi32(
58             _mm_cvtsi32_si128((int)*((const uint32_t*)i)), vi);
59       }
60 
61       const __m128i vxi = _mm_unpacklo_epi8(vi, vzero);
62       vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi, vzero));
63       vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi, vzero));
64     } while (--m != 0);
65     input = next_input;
66 
67     const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
68     const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
69 
70     const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
71     const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
72 
73     __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
74     vout = _mm_adds_epi16(
75         vout,
76         _mm_load_si128(
77             (const __m128i*)quantization_params->sse2.output_zero_point));
78     vout = _mm_packus_epi16(vout, vout);
79     vout = _mm_min_epu8(
80         vout,
81         _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
82     vout = _mm_max_epu8(
83         vout,
84         _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
85 
86     if (kc & 4) {
87       *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
88       output += 4;
89       vout = _mm_srli_epi64(vout, 32);
90     }
91     if (kc & 2) {
92       *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
93       output += 2;
94       vout = _mm_srli_epi32(vout, 16);
95     }
96     if (kc & 1) {
97       *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
98       output += 1;
99     }
100     output = (uint8_t*)((uintptr_t)output + output_increment);
101   } while (--n != 0);
102 }
103