xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 
11 #include <emmintrin.h>
12 
13 #include <qnnpack/u8maxpool.h>
14 
pytorch_u8maxpool_ukernel_16x9p8q__sse2(size_t n,size_t ks,size_t kc,const uint8_t ** input,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC1])15 void pytorch_u8maxpool_ukernel_16x9p8q__sse2(
16     size_t n,
17     size_t ks,
18     size_t kc,
19     const uint8_t** input,
20     uint8_t* output,
21     size_t input_increment,
22     size_t output_increment,
23     const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) {
24   assert(n != 0);
25   assert(ks != 0);
26   assert(kc >= 16);
27 
28   const __m128i voutput_max =
29       _mm_load_si128((const __m128i*)params->sse2.output_max);
30   const __m128i voutput_min =
31       _mm_load_si128((const __m128i*)params->sse2.output_min);
32 
33   do {
34     uint8_t* o = output;
35     {
36       const uint8_t* i0 = *input++;
37       const uint8_t* i1 = *input++;
38       const uint8_t* i2 = *input++;
39       const uint8_t* i3 = *input++;
40       const uint8_t* i4 = *input++;
41       const uint8_t* i5 = *input++;
42       const uint8_t* i6 = *input++;
43       const uint8_t* i7 = *input++;
44       const uint8_t* i8 = *input++;
45       if (ks < 2) {
46         i1 = i0;
47       }
48       if (ks <= 2) {
49         i2 = i0;
50       }
51       if (ks < 4) {
52         i3 = i0;
53       }
54       if (ks <= 4) {
55         i4 = i0;
56       }
57       if (ks < 6) {
58         i5 = i0;
59       }
60       if (ks <= 6) {
61         i6 = i0;
62       }
63       if (ks < 8) {
64         i7 = i0;
65       }
66       if (ks <= 8) {
67         i8 = i0;
68       }
69 
70       size_t k = kc;
71       while (k >= 16) {
72         const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0);
73         i0 += 16;
74         const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1);
75         i1 += 16;
76         const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2);
77         i2 += 16;
78         const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3);
79         i3 += 16;
80         const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4);
81         i4 += 16;
82         const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5);
83         i5 += 16;
84         const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6);
85         i6 += 16;
86         const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7);
87         i7 += 16;
88         const __m128i vi8 = _mm_loadu_si128((const __m128i*)i8);
89         i8 += 16;
90 
91         const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8);
92         const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
93         const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
94         const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
95 
96         const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
97         const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67);
98         const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678);
99         const __m128i vout =
100             _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
101 
102         _mm_storeu_si128((__m128i*)o, vout);
103         o += 16;
104 
105         k -= 16;
106       }
107       if (k != 0) {
108         const size_t address_increment = k - 16;
109         i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
110         i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
111         i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
112         i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
113         i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
114         i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
115         i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
116         i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
117         i8 = (const uint8_t*)((uintptr_t)i8 + address_increment);
118         o = (uint8_t*)((uintptr_t)o + address_increment);
119 
120         const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0);
121         const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1);
122         const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2);
123         const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3);
124         const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4);
125         const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5);
126         const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6);
127         const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7);
128         const __m128i vi8 = _mm_loadu_si128((const __m128i*)i8);
129 
130         const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8);
131         const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
132         const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
133         const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
134 
135         const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
136         const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67);
137         const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678);
138         const __m128i vout =
139             _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
140 
141         _mm_storeu_si128((__m128i*)o, vout);
142         o += 16;
143       }
144     }
145 
146     for (ptrdiff_t m = (ptrdiff_t)ks - 9; m > 0; m -= 8) {
147       const uint8_t* i0 = *input++;
148       const uint8_t* i1 = *input++;
149       const uint8_t* i2 = *input++;
150       const uint8_t* i3 = *input++;
151       const uint8_t* i4 = *input++;
152       const uint8_t* i5 = *input++;
153       const uint8_t* i6 = *input++;
154       const uint8_t* i7 = *input++;
155       if (m < 2) {
156         i1 = i0;
157       }
158       if (m <= 2) {
159         i2 = i0;
160       }
161       if (m < 4) {
162         i3 = i0;
163       }
164       if (m <= 4) {
165         i4 = i0;
166       }
167       if (m < 6) {
168         i5 = i0;
169       }
170       if (m <= 6) {
171         i6 = i0;
172       }
173       if (m < 8) {
174         i7 = i0;
175       }
176 
177       o = output;
178       size_t k = kc;
179       while (k >= 16) {
180         const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0);
181         i0 += 16;
182         const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1);
183         i1 += 16;
184         const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2);
185         i2 += 16;
186         const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3);
187         i3 += 16;
188         const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4);
189         i4 += 16;
190         const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5);
191         i5 += 16;
192         const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6);
193         i6 += 16;
194         const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7);
195         i7 += 16;
196         const __m128i vo = _mm_loadu_si128((const __m128i*)o);
197 
198         const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo);
199         const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
200         const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
201         const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
202 
203         const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
204         const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67);
205         const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167);
206         const __m128i vout =
207             _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
208 
209         _mm_storeu_si128((__m128i*)o, vout);
210         o += 16;
211 
212         k -= 16;
213       }
214       if (k != 0) {
215         const size_t address_increment = k - 16;
216         i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
217         i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
218         i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
219         i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
220         i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
221         i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
222         i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
223         i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
224         o = (uint8_t*)((uintptr_t)o + address_increment);
225 
226         const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0);
227         const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1);
228         const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2);
229         const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3);
230         const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4);
231         const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5);
232         const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6);
233         const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7);
234         const __m128i vo = _mm_loadu_si128((const __m128i*)o);
235 
236         const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo);
237         const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
238         const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
239         const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
240 
241         const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
242         const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67);
243         const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167);
244         const __m128i vout =
245             _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
246 
247         _mm_storeu_si128((__m128i*)o, vout);
248         o += 16;
249       }
250     }
251     input = (const uint8_t**)((uintptr_t)input + input_increment);
252     output = (uint8_t*)((uintptr_t)o + output_increment);
253   } while (--n != 0);
254 }
255