xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/u8maxpool.h>

/*
 * u8 max-pooling micro-kernel for sub-register channel counts (kc < 16).
 *
 * For each of the n output pixels it reduces ks input rows with an
 * element-wise unsigned max over kc channels, clamps the result to
 * [output_min, output_max], and writes exactly kc bytes.
 *
 * input            - array of ks row pointers per output pixel; advanced by
 *                    input_increment bytes after each pixel.
 * output           - destination; advanced by output_increment bytes after
 *                    the kc stored bytes of each pixel.
 */
void pytorch_u8maxpool_ukernel_sub16__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
  assert(n != 0);
  assert(ks != 0);
  assert(kc != 0);
  assert(kc < 16);

  const uint8x16_t vclamp_max = vld1q_dup_u8(&params->neon.output_max);
  const uint8x16_t vclamp_min = vld1q_dup_u8(&params->neon.output_min);
  do {
    /* Running maximum; 0 is the identity element for unsigned max. */
    uint8x16_t vacc = vmovq_n_u8(0);

    size_t k = ks;
    do {
      const uint8_t* i = *input++;
      /*
       * Assemble the kc valid bytes into the low lanes of a q-register by
       * walking backwards from the end of the row, widest-piece-last.
       * Lanes beyond kc are seeded from vacc so the junk in them can never
       * perturb the running maximum.
       */
      i += kc;
      uint8x16_t vrow = vacc;
      if (kc & 1) {
        i -= 1;
        vrow = vld1q_lane_u8(i, vrow, 0);
      }
      if (kc & 2) {
        /* Rotate previously assembled bytes up 2 lanes, then load 2 bytes
         * into lanes 0..1 via an unaligned u16 lane load. */
        vrow = vextq_u8(vrow, vrow, 14);
        i -= 2;
        vrow = vreinterpretq_u8_u16(vld1q_lane_u16(
            __builtin_assume_aligned(i, 1), vreinterpretq_u16_u8(vrow), 0));
      }
      if (kc & 4) {
        /* Rotate up 4 lanes, then load 4 bytes as an unaligned u32 lane. */
        vrow = vextq_u8(vrow, vrow, 12);
        i -= 4;
        vrow = vreinterpretq_u8_u32(vld1q_lane_u32(
            __builtin_assume_aligned(i, 1), vreinterpretq_u32_u8(vrow), 0));
      }
      if (kc & 8) {
        /* Final 8 bytes go into the low half; shift the rest to the top. */
        i -= 8;
        vrow = vcombine_u8(vld1_u8(i), vget_low_u8(vrow));
      }
      vacc = vmaxq_u8(vacc, vrow);
    } while (--k != 0);
    input = (const uint8_t**)((uintptr_t)input + input_increment);

    /* Clamp the pooled maximum into [output_min, output_max]. */
    vacc = vminq_u8(vacc, vclamp_max);
    vacc = vmaxq_u8(vacc, vclamp_min);

    /* Store exactly kc bytes, widest pieces first, shifting consumed
     * lanes out with vext after each partial store. */
    uint8x8_t vout = vget_low_u8(vacc);
    if (kc & 8) {
      vst1_u8(output, vout);
      output += 8;
      vout = vget_high_u8(vacc);
    }
    if (kc & 4) {
      vst1_lane_u32(
          __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
      output += 4;
      vout = vext_u8(vout, vout, 4);
    }
    if (kc & 2) {
      vst1_lane_u16(
          __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
      output += 2;
      vout = vext_u8(vout, vout, 2);
    }
    if (kc & 1) {
      vst1_lane_u8(output, vout, 0);
      output += 1;
    }
    output = (uint8_t*)((uintptr_t)output + output_increment);
  } while (--n != 0);
}
92