1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10
11 #include <arm_neon.h>
12
13 #include <qnnpack/u8clamp.h>
14
pytorch_u8clamp_ukernel__neon(size_t n,const uint8_t * x,uint8_t * y,const union pytorch_qnnp_u8_clamping_params params[restrict static1])15 void pytorch_u8clamp_ukernel__neon(
16 size_t n,
17 const uint8_t* x,
18 uint8_t* y,
19 const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
20 assert(n != 0);
21
22 const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max);
23 const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min);
24
25 if
26 PYTORCH_QNNP_LIKELY(n >= 8) {
27 for (; n >= 64; n -= 64) {
28 const uint8x16_t vx0 = vld1q_u8(x);
29 x += 16;
30 const uint8x16_t vx1 = vld1q_u8(x);
31 x += 16;
32 const uint8x16_t vx2 = vld1q_u8(x);
33 x += 16;
34 const uint8x16_t vx3 = vld1q_u8(x);
35 x += 16;
36
37 const uint8x16_t vy0 =
38 vminq_u8(vmaxq_u8(vx0, voutput_min), voutput_max);
39 const uint8x16_t vy1 =
40 vminq_u8(vmaxq_u8(vx1, voutput_min), voutput_max);
41 const uint8x16_t vy2 =
42 vminq_u8(vmaxq_u8(vx2, voutput_min), voutput_max);
43 const uint8x16_t vy3 =
44 vminq_u8(vmaxq_u8(vx3, voutput_min), voutput_max);
45
46 __builtin_prefetch(x + 640);
47
48 vst1q_u8(y, vy0);
49 y += 16;
50 vst1q_u8(y, vy1);
51 y += 16;
52 vst1q_u8(y, vy2);
53 y += 16;
54 vst1q_u8(y, vy3);
55 y += 16;
56 }
57 for (; n >= 8; n -= 8) {
58 uint8x8_t vout = vld1_u8(x);
59 x += 8;
60 vout = vmin_u8(vout, vget_low_u8(voutput_max));
61 vout = vmax_u8(vout, vget_low_u8(voutput_min));
62 vst1_u8(y, vout);
63 y += 8;
64 }
65 if (n != 0) {
66 const size_t n_increment = n - 8;
67 x = (const uint8_t*)((uintptr_t)x + n_increment);
68 y = (uint8_t*)((uintptr_t)y + n_increment);
69
70 uint8x8_t vout = vld1_u8(x);
71 vout = vmin_u8(vout, vget_low_u8(voutput_max));
72 vout = vmax_u8(vout, vget_low_u8(voutput_min));
73 vst1_u8(y, vout);
74 }
75 }
76 else {
77 do {
78 uint8x8_t vout = vld1_dup_u8(x);
79 x += 1;
80 vout = vmin_u8(vout, vget_low_u8(voutput_max));
81 vout = vmax_u8(vout, vget_low_u8(voutput_min));
82 vst1_lane_u8(y, vout, 0);
83 y += 1;
84 } while (--n != 0);
85 }
86 }
87