xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 
11 #include <arm_neon.h>
12 
13 #include <qnnpack/u8clamp.h>
14 
/*
 * Clamps n bytes read from x into the range [output_min, output_max] (taken
 * from params->neon) and writes the results to y.
 *
 * n must be non-zero (asserted).
 *
 * Inputs of 8 or more bytes are processed in three stages:
 *   1. 64-byte blocks.
 *   2. 8-byte blocks.
 *   3. Any final 1-7 bytes, handled by re-processing the last 8 bytes of the
 *      buffer as an overlapping window.
 * As a result, no load or store touches memory outside x[0..n) / y[0..n).
 * Inputs shorter than 8 bytes are processed one byte at a time.
 *
 * Every path applies the lower bound first and the upper bound second:
 *   y[i] = min(max(x[i], output_min), output_max)
 * so all paths agree even for a degenerate range with output_min > output_max.
 *
 * NOTE(review): the overlapping tail re-reads up to 7 bytes of x that precede
 * the final partial block. This is correct when y == x (clamping is
 * idempotent) or when y does not overlap x. Presumably callers guarantee one
 * of these; confirm against callers.
 */
void pytorch_u8clamp_ukernel__neon(
    size_t n,
    const uint8_t* x,
    uint8_t* y,
    const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
  assert(n != 0);

  const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.output_max);
  const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.output_min);
  /* 64-bit halves of the bounds, used by the 8-byte and scalar paths. */
  const uint8x8_t voutput_max_lo = vget_low_u8(voutput_max);
  const uint8x8_t voutput_min_lo = vget_low_u8(voutput_min);

  if PYTORCH_QNNP_LIKELY(n >= 8) {
    for (; n >= 64; n -= 64) {
      const uint8x16_t vx0 = vld1q_u8(x);
      x += 16;
      const uint8x16_t vx1 = vld1q_u8(x);
      x += 16;
      const uint8x16_t vx2 = vld1q_u8(x);
      x += 16;
      const uint8x16_t vx3 = vld1q_u8(x);
      x += 16;

      const uint8x16_t vy0 =
          vminq_u8(vmaxq_u8(vx0, voutput_min), voutput_max);
      const uint8x16_t vy1 =
          vminq_u8(vmaxq_u8(vx1, voutput_min), voutput_max);
      const uint8x16_t vy2 =
          vminq_u8(vmaxq_u8(vx2, voutput_min), voutput_max);
      const uint8x16_t vy3 =
          vminq_u8(vmaxq_u8(vx3, voutput_min), voutput_max);

      /*
       * Prefetch hint 640 bytes ahead. The address is formed as an integer
       * so that no out-of-bounds pointer is created near the end of the
       * buffer; the prefetch itself does not fault on an invalid address.
       */
      __builtin_prefetch((const void*)((uintptr_t)x + 640));

      vst1q_u8(y, vy0);
      y += 16;
      vst1q_u8(y, vy1);
      y += 16;
      vst1q_u8(y, vy2);
      y += 16;
      vst1q_u8(y, vy3);
      y += 16;
    }
    for (; n >= 8; n -= 8) {
      const uint8x8_t vx = vld1_u8(x);
      x += 8;
      const uint8x8_t vy =
          vmin_u8(vmax_u8(vx, voutput_min_lo), voutput_max_lo);
      vst1_u8(y, vy);
      y += 8;
    }
    if (n != 0) {
      /*
       * 1-7 bytes remain. Step back so that one final 8-byte window ends
       * exactly at the last element. n was >= 8 on entry, so at least 8
       * bytes have already been consumed, and the rewound pointers stay
       * inside the x / y buffers.
       */
      const size_t rewind = 8 - n;
      x -= rewind;
      y -= rewind;

      const uint8x8_t vx = vld1_u8(x);
      const uint8x8_t vy =
          vmin_u8(vmax_u8(vx, voutput_min_lo), voutput_max_lo);
      vst1_u8(y, vy);
    }
  } else {
    /* Fewer than 8 bytes: process one byte at a time via lane 0. */
    do {
      const uint8x8_t vx = vld1_dup_u8(x);
      x += 1;
      const uint8x8_t vy =
          vmin_u8(vmax_u8(vx, voutput_min_lo), voutput_max_lo);
      vst1_lane_u8(y, vy, 0);
      y += 1;
    } while (--n != 0);
  }
}
87