// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <arm_neon.h>
9
10 #include <xnnpack/maxpool.h>
11
12
xnn_s8_maxpool_minmax_ukernel_4p3x__neon_c16(size_t output_pixels,size_t kernel_elements,size_t channels,const int8_t ** input,size_t input_offset,int8_t * output,size_t input_increment,size_t output_increment,const union xnn_s8_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])13 void xnn_s8_maxpool_minmax_ukernel_4p3x__neon_c16(
14 size_t output_pixels,
15 size_t kernel_elements,
16 size_t channels,
17 const int8_t** input,
18 size_t input_offset,
19 int8_t* output,
20 size_t input_increment,
21 size_t output_increment,
22 const union xnn_s8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
23 {
24 assert(output_pixels != 0);
25 assert(kernel_elements != 0);
26 assert(channels != 0);
27
28 const int8x16_t voutput_max = vld1q_dup_s8(¶ms->neon.max);
29 const int8x16_t voutput_min = vld1q_dup_s8(¶ms->neon.min);
30 do {
31 int8_t* o = output;
32 {
33 const int8_t* i0 = *input++;
34 const int8_t* i1 = *input++;
35 const int8_t* i2 = *input++;
36 const int8_t* i3 = *input++;
37 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
38 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
39 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
40 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
41 if (kernel_elements < 2) {
42 i1 = i0;
43 }
44 if (kernel_elements <= 2) {
45 i2 = i0;
46 }
47 if (kernel_elements < 4) {
48 i3 = i0;
49 }
50
51 size_t c = channels;
52 for (; c >= 16; c -= 16) {
53 const int8x16_t vi0 = vld1q_s8(i0); i0 += 16;
54 const int8x16_t vi1 = vld1q_s8(i1); i1 += 16;
55 const int8x16_t vi2 = vld1q_s8(i2); i2 += 16;
56 const int8x16_t vi3 = vld1q_s8(i3); i3 += 16;
57
58 const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
59 const int8x16_t vmax23 = vmaxq_s8(vi2, vi3);
60 int8x16_t vout = vmaxq_s8(vmax01, vmax23);
61 vout = vmaxq_s8(vout, voutput_min);
62 vout = vminq_s8(vout, voutput_max);
63
64 vst1q_s8(o, vout); o += 16;
65 }
66 if (c != 0) {
67 const int8x16_t vi0 = vld1q_s8(i0);
68 const int8x16_t vi1 = vld1q_s8(i1);
69 const int8x16_t vi2 = vld1q_s8(i2);
70 const int8x16_t vi3 = vld1q_s8(i3);
71
72 const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
73 const int8x16_t vmax23 = vmaxq_s8(vi2, vi3);
74 int8x16_t vout = vmaxq_s8(vmax01, vmax23);
75 vout = vmaxq_s8(vout, voutput_min);
76 vout = vminq_s8(vout, voutput_max);
77
78 int8x8_t vout_lo = vget_low_s8(vout);
79 if (c & 8) {
80 vst1_s8(o, vout_lo); o += 8;
81 vout_lo = vget_high_s8(vout);
82 }
83 if (c & 4) {
84 vst1_lane_u32((void*) o, vreinterpret_u32_s8(vout_lo), 0); o += 4;
85 vout_lo = vext_s8(vout_lo, vout_lo, 4);
86 }
87 if (c & 2) {
88 vst1_lane_u16((void*) o, vreinterpret_u16_s8(vout_lo), 0); o += 2;
89 vout_lo = vext_s8(vout_lo, vout_lo, 2);
90 }
91 if (c & 1) {
92 vst1_lane_s8(o, vout_lo, 0); o += 1;
93 }
94 }
95 }
96
97 for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 4; k > 0; k -= 3) {
98 const int8_t* i0 = *input++;
99 const int8_t* i1 = *input++;
100 const int8_t* i2 = *input++;
101 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
102 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
103 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
104 if (k < 2) {
105 i1 = i0;
106 }
107 if (k <= 2) {
108 i2 = i0;
109 }
110
111 o = output;
112 size_t c = channels;
113 for (; c >= 16; c -= 16) {
114 const int8x16_t vi0 = vld1q_s8(i0); i0 += 16;
115 const int8x16_t vi1 = vld1q_s8(i1); i1 += 16;
116 const int8x16_t vi2 = vld1q_s8(i2); i2 += 16;
117 const int8x16_t vo = vld1q_s8(o);
118
119 const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
120 const int8x16_t vmax2o = vmaxq_s8(vi2, vo);
121 int8x16_t vout = vmaxq_s8(vmax01, vmax2o);
122 vout = vmaxq_s8(vout, voutput_min);
123 vout = vminq_s8(vout, voutput_max);
124
125 vst1q_s8(o, vout); o += 16;
126 }
127 if (c != 0) {
128 const int8x16_t vi0 = vld1q_s8(i0);
129 const int8x16_t vi1 = vld1q_s8(i1);
130 const int8x16_t vi2 = vld1q_s8(i2);
131 const int8x16_t vo = vld1q_s8(o);
132
133 const int8x16_t vmax01 = vmaxq_s8(vi0, vi1);
134 const int8x16_t vmax2o = vmaxq_s8(vi2, vo);
135 int8x16_t vout = vmaxq_s8(vmax01, vmax2o);
136 vout = vmaxq_s8(vout, voutput_min);
137 vout = vminq_s8(vout, voutput_max);
138
139 int8x8_t vout_lo = vget_low_s8(vout);
140 if (c & 8) {
141 vst1_s8(o, vout_lo); o += 8;
142 vout_lo = vget_high_s8(vout);
143 }
144 if (c & 4) {
145 vst1_lane_u32((void*) o, vreinterpret_u32_s8(vout_lo), 0); o += 4;
146 vout_lo = vext_s8(vout_lo, vout_lo, 4);
147 }
148 if (c & 2) {
149 vst1_lane_u16((void*) o, vreinterpret_u16_s8(vout_lo), 0); o += 2;
150 vout_lo = vext_s8(vout_lo, vout_lo, 2);
151 }
152 if (c & 1) {
153 vst1_lane_s8(o, vout_lo, 0); o += 1;
154 }
155 }
156 }
157 input = (const int8_t**) ((uintptr_t) input + input_increment);
158 output = (int8_t*) ((uintptr_t) o + output_increment);
159 } while (--output_pixels != 0);
160 }
161