1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <arm_neon.h>
9
10 #include <xnnpack/maxpool.h>
11
12
// Signed 8-bit max-pooling microkernel with output clamping (NEON, 16 channels
// per vector register).
//
// For every output pixel the kernel takes the per-channel maximum over the
// input rows referenced by the `input` pointer array, clamps it to
// [params->neon.min, params->neon.max] and stores `channels` bytes to `output`.
//
// Parameters:
//   output_pixels    - number of output pixels to produce; must be non-zero.
//   kernel_elements  - number of input rows pooled per output pixel; must be
//                      non-zero.
//   channels         - number of int8 channels per row; must be non-zero.
//   input            - array of row pointers. Pointers are always consumed two
//                      at a time: one pair for the first pass, then one pair
//                      per remaining (up to) two kernel elements. When a pair
//                      has only one live element, the second pointer is still
//                      read from the array but replaced by the first, so the
//                      array must hold a readable entry in that slot.
//   input_offset     - byte offset added to every row pointer before use.
//   output           - destination for the first output pixel.
//   input_increment  - byte count added to `input` after all pairs of one
//                      output pixel have been consumed.
//   output_increment - byte count added to the output pointer after the
//                      `channels` bytes of one pixel have been written.
//   params           - supplies the clamping bounds `neon.min` / `neon.max`.
//
// The remainder path (channels not a multiple of 16) loads full 16-byte
// vectors, so it may read up to 15 bytes past the `channels` valid bytes of
// each input row and, in accumulation passes, of the output row. Only
// `channels` bytes are ever stored. This is presumably what XNN_OOB_READS
// annotates - confirm against the macro definition.
void xnn_s8_maxpool_minmax_ukernel_2p2x__neon_c16(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const int8_t** input,
    size_t input_offset,
    int8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_s8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements != 0);
  assert(channels != 0);

  // Broadcast the scalar clamping bounds to all 16 lanes once, up front.
  const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.max);
  const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.min);
  do {
    int8_t* o = output;
    // First pass: max of the first two rows, written straight to the output.
    {
      const int8_t* i0 = *input++;
      const int8_t* i1 = *input++;
      i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
      i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
      // With a single kernel element the second pointer is not a live row;
      // alias it to the first so that max(i0, i1) == i0.
      if (kernel_elements < 2) {
        i1 = i0;
      }

      size_t c = channels;
      for (; c >= 16; c -= 16) {
        const int8x16_t vi0 = vld1q_s8(i0); i0 += 16;
        const int8x16_t vi1 = vld1q_s8(i1); i1 += 16;

        int8x16_t vout = vmaxq_s8(vi0, vi1);
        vout = vmaxq_s8(vout, voutput_min);
        vout = vminq_s8(vout, voutput_max);

        vst1q_s8(o, vout); o += 16;
      }
      if (c != 0) {
        // Remainder of 1..15 channels: compute a full vector, then store it
        // piecewise in 8/4/2/1-byte chunks selected by the bits of `c`.
        const int8x16_t vi0 = vld1q_s8(i0);
        const int8x16_t vi1 = vld1q_s8(i1);

        int8x16_t vout = vmaxq_s8(vi0, vi1);
        vout = vmaxq_s8(vout, voutput_min);
        vout = vminq_s8(vout, voutput_max);

        int8x8_t vout_lo = vget_low_s8(vout);
        if (c & 8) {
          vst1_s8(o, vout_lo); o += 8;
          vout_lo = vget_high_s8(vout);
        }
        if (c & 4) {
          vst1_lane_u32((void*) o, vreinterpret_u32_s8(vout_lo), 0); o += 4;
          // Rotate so the next unstored bytes move down to lane 0.
          vout_lo = vext_s8(vout_lo, vout_lo, 4);
        }
        if (c & 2) {
          vst1_lane_u16((void*) o, vreinterpret_u16_s8(vout_lo), 0); o += 2;
          vout_lo = vext_s8(vout_lo, vout_lo, 2);
        }
        if (c & 1) {
          vst1_lane_s8(o, vout_lo, 0); o += 1;
        }
      }
    }

    // Accumulation passes: fold two more rows per iteration into the values
    // already stored in the output. `k` is signed so that kernel_elements < 2
    // yields a negative start value and the loop is skipped.
    for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 2; k > 0; k -= 2) {
      const int8_t* i0 = *input++;
      const int8_t* i1 = *input++;
      i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
      i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
      // Only one element left in this pair: alias the second row to the first.
      if (k < 2) {
        i1 = i0;
      }

      // Rewind to the start of the current output pixel.
      o = output;
      size_t c = channels;
      for (; c >= 16; c -= 16) {
        const int8x16_t vi0 = vld1q_s8(i0); i0 += 16;
        const int8x16_t vi1 = vld1q_s8(i1); i1 += 16;
        const int8x16_t vo = vld1q_s8(o);

        const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
        int8x16_t vout = vmaxq_s8(vo, vmax01);
        vout = vmaxq_s8(vout, voutput_min);
        vout = vminq_s8(vout, voutput_max);

        vst1q_s8(o, vout); o += 16;
      }
      if (c != 0) {
        const int8x16_t vi0 = vld1q_s8(i0);
        const int8x16_t vi1 = vld1q_s8(i1);
        // Full 16-byte load of the partial output row; lanes at or beyond `c`
        // are computed but never stored back.
        const int8x16_t vo = vld1q_s8(o);

        const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
        int8x16_t vout = vmaxq_s8(vo, vmax01);
        vout = vmaxq_s8(vout, voutput_min);
        vout = vminq_s8(vout, voutput_max);

        int8x8_t vout_lo = vget_low_s8(vout);
        if (c & 8) {
          vst1_s8(o, vout_lo); o += 8;
          vout_lo = vget_high_s8(vout);
        }
        if (c & 4) {
          vst1_lane_u32((void*) o, vreinterpret_u32_s8(vout_lo), 0); o += 4;
          vout_lo = vext_s8(vout_lo, vout_lo, 4);
        }
        if (c & 2) {
          vst1_lane_u16((void*) o, vreinterpret_u16_s8(vout_lo), 0); o += 2;
          vout_lo = vext_s8(vout_lo, vout_lo, 2);
        }
        if (c & 1) {
          vst1_lane_s8(o, vout_lo, 0); o += 1;
        }
      }
    }
    // Every pass leaves `o` at output + channels, so the output increment is
    // applied relative to the end of the pixel just written.
    input = (const int8_t**) ((uintptr_t) input + input_increment);
    output = (int8_t*) ((uintptr_t) o + output_increment);
  } while (--output_pixels != 0);
}
134