xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 
11 #include <emmintrin.h>
12 
13 #include <qnnpack/u8clamp.h>
14 
pytorch_u8clamp_ukernel__sse2(size_t n,const uint8_t * x,uint8_t * y,const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC1])15 void pytorch_u8clamp_ukernel__sse2(
16     size_t n,
17     const uint8_t* x,
18     uint8_t* y,
19     const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) {
20   assert(n != 0);
21 
22   if
23     PYTORCH_QNNP_LIKELY(n >= 8) {
24       const __m128i voutput_max =
25           _mm_load_si128((const __m128i*)&params->sse2.output_max);
26       const __m128i voutput_min =
27           _mm_load_si128((const __m128i*)&params->sse2.output_min);
28       for (; n >= 64; n -= 64) {
29         const __m128i vx0 = _mm_loadu_si128((const __m128i*)x);
30         const __m128i vx1 = _mm_loadu_si128((const __m128i*)x + 1);
31         const __m128i vx2 = _mm_loadu_si128((const __m128i*)x + 2);
32         const __m128i vx3 = _mm_loadu_si128((const __m128i*)x + 3);
33         x += 64;
34 
35         const __m128i vy0 =
36             _mm_min_epu8(_mm_max_epu8(vx0, voutput_min), voutput_max);
37         const __m128i vy1 =
38             _mm_min_epu8(_mm_max_epu8(vx1, voutput_min), voutput_max);
39         const __m128i vy2 =
40             _mm_min_epu8(_mm_max_epu8(vx2, voutput_min), voutput_max);
41         const __m128i vy3 =
42             _mm_min_epu8(_mm_max_epu8(vx3, voutput_min), voutput_max);
43 
44         __builtin_prefetch(x + 640);
45 
46         _mm_storeu_si128((__m128i*)y, vy0);
47         _mm_storeu_si128((__m128i*)y + 1, vy1);
48         _mm_storeu_si128((__m128i*)y + 2, vy2);
49         _mm_storeu_si128((__m128i*)y + 3, vy3);
50         y += 64;
51       }
52       for (; n >= 8; n -= 8) {
53         __m128i vout = _mm_loadl_epi64((const __m128i*)x);
54         x += 8;
55         vout = _mm_min_epu8(vout, voutput_max);
56         vout = _mm_max_epu8(vout, voutput_min);
57         _mm_storel_epi64((__m128i*)y, vout);
58         y += 8;
59       }
60       if (n != 0) {
61         const size_t n_increment = n - 8;
62         x = (const uint8_t*)((uintptr_t)x + n_increment);
63         y = (uint8_t*)((uintptr_t)y + n_increment);
64 
65         __m128i vout = _mm_loadl_epi64((const __m128i*)x);
66         vout = _mm_min_epu8(vout, voutput_max);
67         vout = _mm_max_epu8(vout, voutput_min);
68         _mm_storel_epi64((__m128i*)y, vout);
69       }
70     }
71   else {
72     const uint32_t voutput_max = params->sse2.output_max[0];
73     const uint32_t voutput_min = params->sse2.output_min[0];
74     do {
75       uint32_t vout = *x++;
76       vout = vout > voutput_max ? voutput_max : vout;
77       vout = vout < voutput_min ? voutput_min : vout;
78       *y++ = (uint8_t)vout;
79     } while (--n != 0);
80   }
81 }
82