xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#include <assert.h>
#include <string.h>

#include <emmintrin.h>

#include <qnnpack/u8maxpool.h>
14 
pytorch_u8maxpool_ukernel_sub16__sse2(size_t n,size_t ks,size_t kc,const uint8_t ** input,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC1])15 void pytorch_u8maxpool_ukernel_sub16__sse2(
16     size_t n,
17     size_t ks,
18     size_t kc,
19     const uint8_t** input,
20     uint8_t* output,
21     size_t input_increment,
22     size_t output_increment,
23     const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) {
24   assert(n != 0);
25   assert(ks != 0);
26   assert(kc != 0);
27   assert(kc < 16);
28 
29   const __m128i voutput_max =
30       _mm_load_si128((const __m128i*)params->sse2.output_max);
31   const __m128i voutput_min =
32       _mm_load_si128((const __m128i*)params->sse2.output_min);
33 
34   do {
35     __m128i vmax = _mm_setzero_si128();
36 
37     size_t m = ks;
38     do {
39       const uint8_t* i = *input++;
40       i += kc;
41       __m128i vi = vmax;
42       if (kc & 1) {
43         i -= 1;
44         vi = _mm_cvtsi32_si128(*i);
45       }
46       if (kc & 2) {
47         vi = _mm_slli_epi32(vi, 16);
48         i -= 2;
49         vi = _mm_insert_epi16(vi, *((const uint16_t*)i), 0);
50       }
51       if (kc & 4) {
52         i -= 4;
53         vi = _mm_unpacklo_epi32(
54             _mm_cvtsi32_si128((int)*((const uint32_t*)i)), vi);
55       }
56       if (kc & 8) {
57         i -= 8;
58         vi = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)i), vi);
59       }
60       vmax = _mm_max_epu8(vmax, vi);
61     } while (--m != 0);
62     input = (const uint8_t**)((uintptr_t)input + input_increment);
63     __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
64 
65     if (kc & 8) {
66       _mm_storel_epi64((__m128i*)output, vout);
67       output += 8;
68       vout = _mm_unpackhi_epi64(vout, vout);
69     }
70     if (kc & 4) {
71       *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
72       output += 4;
73       vout = _mm_srli_epi64(vout, 32);
74     }
75     if (kc & 2) {
76       *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
77       output += 2;
78       vout = _mm_srli_epi32(vout, 16);
79     }
80     if (kc & 1) {
81       *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
82       output += 1;
83     }
84     output = (uint8_t*)((uintptr_t)output + output_increment);
85   } while (--n != 0);
86 }
87