1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10
11 #include <emmintrin.h>
12
13 #include <qnnpack/u8clamp.h>
14
pytorch_u8clamp_ukernel__sse2(size_t n,const uint8_t * x,uint8_t * y,const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC1])15 void pytorch_u8clamp_ukernel__sse2(
16 size_t n,
17 const uint8_t* x,
18 uint8_t* y,
19 const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) {
20 assert(n != 0);
21
22 if
23 PYTORCH_QNNP_LIKELY(n >= 8) {
24 const __m128i voutput_max =
25 _mm_load_si128((const __m128i*)¶ms->sse2.output_max);
26 const __m128i voutput_min =
27 _mm_load_si128((const __m128i*)¶ms->sse2.output_min);
28 for (; n >= 64; n -= 64) {
29 const __m128i vx0 = _mm_loadu_si128((const __m128i*)x);
30 const __m128i vx1 = _mm_loadu_si128((const __m128i*)x + 1);
31 const __m128i vx2 = _mm_loadu_si128((const __m128i*)x + 2);
32 const __m128i vx3 = _mm_loadu_si128((const __m128i*)x + 3);
33 x += 64;
34
35 const __m128i vy0 =
36 _mm_min_epu8(_mm_max_epu8(vx0, voutput_min), voutput_max);
37 const __m128i vy1 =
38 _mm_min_epu8(_mm_max_epu8(vx1, voutput_min), voutput_max);
39 const __m128i vy2 =
40 _mm_min_epu8(_mm_max_epu8(vx2, voutput_min), voutput_max);
41 const __m128i vy3 =
42 _mm_min_epu8(_mm_max_epu8(vx3, voutput_min), voutput_max);
43
44 __builtin_prefetch(x + 640);
45
46 _mm_storeu_si128((__m128i*)y, vy0);
47 _mm_storeu_si128((__m128i*)y + 1, vy1);
48 _mm_storeu_si128((__m128i*)y + 2, vy2);
49 _mm_storeu_si128((__m128i*)y + 3, vy3);
50 y += 64;
51 }
52 for (; n >= 8; n -= 8) {
53 __m128i vout = _mm_loadl_epi64((const __m128i*)x);
54 x += 8;
55 vout = _mm_min_epu8(vout, voutput_max);
56 vout = _mm_max_epu8(vout, voutput_min);
57 _mm_storel_epi64((__m128i*)y, vout);
58 y += 8;
59 }
60 if (n != 0) {
61 const size_t n_increment = n - 8;
62 x = (const uint8_t*)((uintptr_t)x + n_increment);
63 y = (uint8_t*)((uintptr_t)y + n_increment);
64
65 __m128i vout = _mm_loadl_epi64((const __m128i*)x);
66 vout = _mm_min_epu8(vout, voutput_max);
67 vout = _mm_max_epu8(vout, voutput_min);
68 _mm_storel_epi64((__m128i*)y, vout);
69 }
70 }
71 else {
72 const uint32_t voutput_max = params->sse2.output_max[0];
73 const uint32_t voutput_min = params->sse2.output_min[0];
74 do {
75 uint32_t vout = *x++;
76 vout = vout > voutput_max ? voutput_max : vout;
77 vout = vout < voutput_min ? voutput_min : vout;
78 *y++ = (uint8_t)vout;
79 } while (--n != 0);
80 }
81 }
82