/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/q8gavgpool.h>
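
/*
 * Quantized 8-bit global average pooling micro-kernel, unipass ("up")
 * variant: one call reduces up to 7 input rows ("x7") over at least 8
 * channels ("8" at a time), accumulating in 32-bit lanes and requantizing
 * with NEON.
 */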
14
void pytorch_q8gavgpool_ukernel_up8x7__neon(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    uint8_t* output,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(m >= 1);
  assert(m <= 7);
  assert(n >= 8);

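  /*
   * Set up one pointer per input row. Rows past the m rows actually present
   * point at the caller-provided zero buffer, so they add nothing to the
   * sums and the loops below can unconditionally read 7 rows.
   */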
  const uint8_t* i0 = input;
  const uint8_t* i1 = i0 + input_stride;
  if (m < 2) {
    i1 = zero;
  }
  const uint8_t* i2 = i1 + input_stride;
  if (m <= 2) {
    i2 = zero;
  }
  const uint8_t* i3 = i2 + input_stride;
  if (m < 4) {
    i3 = zero;
  }
  const uint8_t* i4 = i3 + input_stride;
  if (m <= 4) {
    i4 = zero;
  }
  const uint8_t* i5 = i4 + input_stride;
  if (m < 6) {
    i5 = zero;
  }
  const uint8_t* i6 = i5 + input_stride;
  if (m <= 6) {
    i6 = zero;
  }
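  /*
   * Requantization parameters: a 32-bit bias added into every accumulator
   * and a float scale. On AArch64 the result is converted back with
   * round-to-nearest and clamped in the integer domain; on AArch32 it is
   * clamped in the float domain and converted via the vfmagic/vimagic
   * magic-number trick.
   */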
  const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale);
#if defined(__aarch64__)
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t voutput_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t voutput_max =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

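  /* Main loop: process 8 channels per iteration while at least 8 remain. */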
  do {
    const uint8x8_t vi0 = vld1_u8(i0);
    i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1);
    i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2);
    i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3);
    i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4);
    i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5);
    i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6);
    i6 += 8;

    const int16x8_t vsum016 =
        vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6));
    const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
    const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));

    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23));
    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016));

    float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
    float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

    vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
    vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

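    /*
     * Requantize the scaled accumulators to uint8: AArch64 rounds to
     * nearest, narrows with saturation, adds the output zero point, and
     * clamps to [output_min, output_max]; AArch32 clamps in float and uses
     * the magic-number conversion before narrowing.
     */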
#if defined(__aarch64__)
    vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
    vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
    const int16x8_t vacc = vqaddq_s16(
        vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);
#else
    vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
    vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

    vacc_lo = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
    vacc_hi = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
    const int16x8_t vacc =
        vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
    uint8x8_t vout = vqmovun_s16(vacc);
#endif

    vst1_u8(output, vout);
    output += 8;

    n -= 8;
  } while (n >= 8);
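  /*
   * Tail: 1-7 channels remain. Each row pointer is moved back so a final
   * 8-byte load ends exactly at the row's last channel; bytes that overlap
   * the previous, already-stored iteration are shifted out below.
   */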
  if (n != 0) {
    const size_t address_increment = n - 8;
    i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
    i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
    i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
    i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
    i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
    i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
    i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
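    /*
     * 8 * address_increment is 8 * (n - 8), i.e. a negative bit count once
     * interpreted as a signed shift; VSHL with a negative count shifts
     * right, discarding the (8 - n) low bytes that were already processed
     * and leaving the n remaining bytes in the low lanes (zeros above).
     */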
    const int64x1_t vshift = vmov_n_s64(8 * address_increment);

    const uint8x8_t vi0 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
    const uint8x8_t vi1 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
    const uint8x8_t vi2 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
    const uint8x8_t vi3 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
    const uint8x8_t vi4 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
    const uint8x8_t vi5 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
    const uint8x8_t vi6 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));

    const int16x8_t vsum016 =
        vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6));
    const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
    const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));

    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23));
    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016));

    float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
    float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

    vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
    vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
    vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
    vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
    const int16x8_t vacc = vqaddq_s16(
        vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);
#else
    vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
    vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

    vacc_lo = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
    vacc_hi = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
    const int16x8_t vacc =
        vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
    uint8x8_t vout = vqmovun_s16(vacc);
#endif

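    /*
     * Store the remaining n output bytes with progressively narrower lane
     * stores (4, then 2, then 1 byte), rotating the vector after each store
     * so the next bytes move into lane 0.
     */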
    if (n & 4) {
      vst1_lane_u32(
          __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
      output += 4;
      vout = vext_u8(vout, vout, 4);
    }
    if (n & 2) {
      vst1_lane_u16(
          __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
      output += 2;
      vout = vext_u8(vout, vout, 2);
    }
    if (n & 1) {
      vst1_lane_u8(output, vout, 0);
    }
  }
}