xref: /aosp_15_r20/external/XNNPACK/src/s8-maxpool/2p2x-minmax-neon-c16.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <arm_neon.h>
9 
10 #include <xnnpack/maxpool.h>
11 
12 
// Signed 8-bit max-pooling microkernel with output clamping (NEON, 16 channels
// per vector step).
//
// For each output pixel the kernel consumes pooling-input row pointers from
// `input` two at a time. The first pair is reduced with an element-wise max,
// clamped and written to `output`. Every further pair is reduced the same way
// and merged (max) with the values already stored in `output`.
//
//   output_pixels    - number of output pixels to produce (must be non-zero).
//   kernel_elements  - number of pooling inputs per output pixel (non-zero).
//                      Pointers are always fetched in pairs; when only one
//                      element of a pair is in use, the second fetched pointer
//                      is discarded and the first is used for both operands.
//   channels         - number of int8 channels per pixel (non-zero).
//   input            - array of row pointers, advanced as they are consumed.
//   input_offset     - byte offset added to every row pointer read from `input`.
//   output           - destination for the first output pixel.
//   input_increment  - bytes added to `input` after each pixel's pointer pairs
//                      have been consumed.
//   output_increment - bytes added to the end of a written pixel to reach the
//                      next pixel's destination.
//   params           - supplies the clamp bounds `neon.min` / `neon.max`.
//
// Note: the remainder path (channels not a multiple of 16) still issues full
// 16-byte vector loads from the input rows and, in follow-up passes, from
// `output`; only the valid bytes are stored. This is presumably why the kernel
// carries the XNN_OOB_READS annotation.
void xnn_s8_maxpool_minmax_ukernel_2p2x__neon_c16(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const int8_t** input,
    size_t input_offset,
    int8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_s8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements != 0);
  assert(channels != 0);

  // Broadcast the clamp bounds across all 16 lanes once, up front.
  const int8x16_t vclamp_hi = vld1q_dup_s8(&params->neon.max);
  const int8x16_t vclamp_lo = vld1q_dup_s8(&params->neon.min);

  do {
    int8_t* dst = output;

    // Initial pass: reduce the first pair of rows and write the result.
    {
      const int8_t* row_a = (const int8_t*) ((uintptr_t) input[0] + input_offset);
      const int8_t* row_b = (const int8_t*) ((uintptr_t) input[1] + input_offset);
      input += 2;
      if (kernel_elements <= 1) {
        // Single pooling element: pool the row against itself.
        row_b = row_a;
      }

      size_t remaining = channels;
      while (remaining >= 16) {
        const int8x16_t va = vld1q_s8(row_a);
        const int8x16_t vb = vld1q_s8(row_b);
        row_a += 16;
        row_b += 16;

        const int8x16_t vpooled = vmaxq_s8(va, vb);
        const int8x16_t vclamped = vminq_s8(vmaxq_s8(vpooled, vclamp_lo), vclamp_hi);

        vst1q_s8(dst, vclamped);
        dst += 16;
        remaining -= 16;
      }
      if (remaining != 0) {
        // 1..15 channels left: compute a full vector, store only the valid bytes.
        const int8x16_t va = vld1q_s8(row_a);
        const int8x16_t vb = vld1q_s8(row_b);

        const int8x16_t vpooled = vmaxq_s8(va, vb);
        const int8x16_t vclamped = vminq_s8(vmaxq_s8(vpooled, vclamp_lo), vclamp_hi);

        // Emit 8, 4, 2 and 1 byte pieces according to the bits of `remaining`,
        // rotating the not-yet-stored bytes down to lane 0 after each store.
        int8x8_t vpiece = vget_low_s8(vclamped);
        if ((remaining & 8) != 0) {
          vst1_s8(dst, vpiece);
          dst += 8;
          vpiece = vget_high_s8(vclamped);
        }
        if ((remaining & 4) != 0) {
          vst1_lane_u32((void*) dst, vreinterpret_u32_s8(vpiece), 0);
          dst += 4;
          vpiece = vext_s8(vpiece, vpiece, 4);
        }
        if ((remaining & 2) != 0) {
          vst1_lane_u16((void*) dst, vreinterpret_u16_s8(vpiece), 0);
          dst += 2;
          vpiece = vext_s8(vpiece, vpiece, 2);
        }
        if ((remaining & 1) != 0) {
          vst1_lane_s8(dst, vpiece, 0);
          dst += 1;
        }
      }
    }

    // Follow-up passes: fold each further pair of rows into the stored output.
    ptrdiff_t pending = (ptrdiff_t) kernel_elements - 2;
    while (pending > 0) {
      // Each pass rewrites the same output pixel from its start.
      dst = output;

      const int8_t* row_a = (const int8_t*) ((uintptr_t) input[0] + input_offset);
      const int8_t* row_b = (const int8_t*) ((uintptr_t) input[1] + input_offset);
      input += 2;
      if (pending <= 1) {
        // Only one pooling element left in this pair: reuse the first row.
        row_b = row_a;
      }

      size_t remaining = channels;
      while (remaining >= 16) {
        const int8x16_t va = vld1q_s8(row_a);
        const int8x16_t vb = vld1q_s8(row_b);
        const int8x16_t vprev = vld1q_s8(dst);
        row_a += 16;
        row_b += 16;

        const int8x16_t vpair = vmaxq_s8(va, vb);
        const int8x16_t vpooled = vmaxq_s8(vprev, vpair);
        const int8x16_t vclamped = vminq_s8(vmaxq_s8(vpooled, vclamp_lo), vclamp_hi);

        vst1q_s8(dst, vclamped);
        dst += 16;
        remaining -= 16;
      }
      if (remaining != 0) {
        // Remainder: full-width loads (inputs and previous output), partial store.
        const int8x16_t va = vld1q_s8(row_a);
        const int8x16_t vb = vld1q_s8(row_b);
        const int8x16_t vprev = vld1q_s8(dst);

        const int8x16_t vpair = vmaxq_s8(va, vb);
        const int8x16_t vpooled = vmaxq_s8(vprev, vpair);
        const int8x16_t vclamped = vminq_s8(vmaxq_s8(vpooled, vclamp_lo), vclamp_hi);

        int8x8_t vpiece = vget_low_s8(vclamped);
        if ((remaining & 8) != 0) {
          vst1_s8(dst, vpiece);
          dst += 8;
          vpiece = vget_high_s8(vclamped);
        }
        if ((remaining & 4) != 0) {
          vst1_lane_u32((void*) dst, vreinterpret_u32_s8(vpiece), 0);
          dst += 4;
          vpiece = vext_s8(vpiece, vpiece, 4);
        }
        if ((remaining & 2) != 0) {
          vst1_lane_u16((void*) dst, vreinterpret_u16_s8(vpiece), 0);
          dst += 2;
          vpiece = vext_s8(vpiece, vpiece, 2);
        }
        if ((remaining & 1) != 0) {
          vst1_lane_s8(dst, vpiece, 0);
          dst += 1;
        }
      }

      pending -= 2;
    }

    // Step to the next pixel: `dst` points just past the channels written above.
    input = (const int8_t**) ((uintptr_t) input + input_increment);
    output = (int8_t*) ((uintptr_t) dst + output_increment);
  } while (--output_pixels != 0);
}
134