1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10
11 #include <arm_neon.h>
12
13 #include <qnnpack/u8maxpool.h>
14
pytorch_u8maxpool_ukernel_sub16__neon(size_t n,size_t ks,size_t kc,const uint8_t ** input,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_u8_clamping_params params[restrict static1])15 void pytorch_u8maxpool_ukernel_sub16__neon(
16 size_t n,
17 size_t ks,
18 size_t kc,
19 const uint8_t** input,
20 uint8_t* output,
21 size_t input_increment,
22 size_t output_increment,
23 const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
24 assert(n != 0);
25 assert(ks != 0);
26 assert(kc != 0);
27 assert(kc < 16);
28
29 const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max);
30 const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min);
31 do {
32 uint8x16_t vmax = vmovq_n_u8(0);
33
34 size_t m = ks;
35 do {
36 const uint8_t* i = *input++;
37 i += kc;
38 uint8x16_t vi = vmax;
39 if (kc & 1) {
40 i -= 1;
41 vi = vld1q_lane_u8(i, vi, 0);
42 }
43 if (kc & 2) {
44 vi = vextq_u8(vi, vi, 14);
45 i -= 2;
46 vi = vreinterpretq_u8_u16(vld1q_lane_u16(
47 __builtin_assume_aligned(i, 1), vreinterpretq_u16_u8(vi), 0));
48 }
49 if (kc & 4) {
50 vi = vextq_u8(vi, vi, 12);
51 i -= 4;
52 vi = vreinterpretq_u8_u32(vld1q_lane_u32(
53 __builtin_assume_aligned(i, 1), vreinterpretq_u32_u8(vi), 0));
54 }
55 if (kc & 8) {
56 i -= 8;
57 vi = vcombine_u8(vld1_u8(i), vget_low_u8(vi));
58 }
59 vmax = vmaxq_u8(vmax, vi);
60 } while (--m != 0);
61 input = (const uint8_t**)((uintptr_t)input + input_increment);
62
63 vmax = vminq_u8(vmax, voutput_max);
64 vmax = vmaxq_u8(vmax, voutput_min);
65
66 uint8x8_t vout = vget_low_u8(vmax);
67 if (kc & 8) {
68 vst1_u8(output, vout);
69 output += 8;
70 vout = vget_high_u8(vmax);
71 }
72 if (kc & 4) {
73 vst1_lane_u32(
74 __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
75 output += 4;
76 vout = vext_u8(vout, vout, 4);
77 }
78 if (kc & 2) {
79 vst1_lane_u16(
80 __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
81 output += 2;
82 vout = vext_u8(vout, vout, 2);
83 }
84 if (kc & 1) {
85 vst1_lane_u8(output, vout, 0);
86 output += 1;
87 }
88 output = (uint8_t*)((uintptr_t)output + output_increment);
89
90 } while (--n != 0);
91 }
92