/* xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45) */
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/q8gavgpool.h>

/*
 * Global average-pooling micro-kernel, "up8x7" variant: a unipass kernel that
 * sums up to 7 input rows, 8 channels at a time, and requantizes the per-channel
 * sums into uint8 outputs.
 *
 *   m            - number of input rows to pool, 1 <= m <= 7
 *   n            - number of channels (elements per row), n >= 8
 *   input        - first input row; row r starts at input + r * input_stride
 *   input_stride - byte stride between consecutive input rows
 *   zero         - zero-filled buffer substituted for rows beyond m, so the
 *                  inner loop can unconditionally read 7 rows
 *   output       - receives n requantized uint8 results
 *   quantization_params - bias, scale, and output clamping parameters in the
 *                  NEON-specific layout (see the union's `neon` member)
 */
void pytorch_q8gavgpool_ukernel_up8x7__neon(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    uint8_t* output,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(m >= 1);
  assert(m <= 7);
  assert(n >= 8);

  /* Set up 7 row pointers; rows at index >= m alias the `zero` buffer so
   * they contribute nothing to the sums while keeping the loop branch-free. */
  const uint8_t* i0 = input;
  const uint8_t* i1 = i0 + input_stride;
  if (m < 2) {
    i1 = zero;
  }
  const uint8_t* i2 = i1 + input_stride;
  if (m <= 2) {
    i2 = zero;
  }
  const uint8_t* i3 = i2 + input_stride;
  if (m < 4) {
    i3 = zero;
  }
  const uint8_t* i4 = i3 + input_stride;
  if (m <= 4) {
    i4 = zero;
  }
  const uint8_t* i5 = i4 + input_stride;
  if (m < 6) {
    i5 = zero;
  }
  const uint8_t* i6 = i5 + input_stride;
  if (m <= 6) {
    i6 = zero;
  }
  /* Bias is added to every 32-bit accumulator before scaling; scale converts
   * the integer sum into the output quantization domain. */
  const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale);
#if defined(__aarch64__)
  /* AArch64 path: round with vcvtnq (round-to-nearest-even), then add the
   * output zero point and clamp in the integer domain. */
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t voutput_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t voutput_max =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  /* ARMv7 path: clamp in the float domain, then use the "magic number" trick
   * (add vfmagic, reinterpret the float bits, subtract vimagic) to round
   * float -> int without a rounding-mode-dependent conversion. */
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

  /* Main loop: process 8 channels per iteration (n >= 8 guarantees at least
   * one full iteration). */
  do {
    const uint8x8_t vi0 = vld1_u8(i0);
    i0 += 8;
    const uint8x8_t vi1 = vld1_u8(i1);
    i1 += 8;
    const uint8x8_t vi2 = vld1_u8(i2);
    i2 += 8;
    const uint8x8_t vi3 = vld1_u8(i3);
    i3 += 8;
    const uint8x8_t vi4 = vld1_u8(i4);
    i4 += 8;
    const uint8x8_t vi5 = vld1_u8(i5);
    i5 += 8;
    const uint8x8_t vi6 = vld1_u8(i6);
    i6 += 8;

    /* Widening u8 adds; a sum of at most 7 * 255 = 1785 fits in 15 bits, so
     * reinterpreting the u16 partial sums as s16 is lossless. */
    const int16x8_t vsum016 =
        vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6));
    const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
    const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));

    /* Widen to 32-bit accumulators seeded with the bias. */
    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23));
    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016));

    /* Requantize: scale in float, then convert back to integer. */
    float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
    float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

    vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
    vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
    vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
    vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
    /* Saturating narrow to s16, add zero point, saturating narrow to u8,
     * then clamp to [output_min, output_max]. */
    const int16x8_t vacc = vqaddq_s16(
        vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);
#else
    /* Clamp in float; vfmin/vfmax already account for the output zero point
     * folded in by the magic-number constants below. */
    vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
    vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

    /* Magic-number rounding: adding vfmagic places the value in the low
     * mantissa bits; subtracting vimagic recovers the rounded integer. */
    vacc_lo = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
    vacc_hi = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
    const int16x8_t vacc =
        vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
    uint8x8_t vout = vqmovun_s16(vacc);
#endif

    vst1_u8(output, vout);
    output += 8;

    n -= 8;
  } while (n >= 8);
  /* Remainder: 1 <= n <= 7 channels left. Rewind each pointer by (8 - n) so
   * an 8-byte load ends exactly at the row end; this re-reads (8 - n)
   * already-processed bytes, which are shifted out below. */
  if (n != 0) {
    const size_t address_increment = n - 8;
    i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
    i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
    i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
    i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
    i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
    i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
    i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
    /* Negative shift count: VSHL with a negative element shifts right, so
     * this discards the (8 - n) stale low-order (little-endian: leading)
     * bytes and zero-fills the unused high lanes. */
    const int64x1_t vshift = vmov_n_s64(8 * address_increment);

    const uint8x8_t vi0 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
    const uint8x8_t vi1 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
    const uint8x8_t vi2 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
    const uint8x8_t vi3 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
    const uint8x8_t vi4 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
    const uint8x8_t vi5 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
    const uint8x8_t vi6 =
        vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));

    /* Same sum/requantize sequence as the main loop; lanes past n hold
     * zero-input results that are never stored. */
    const int16x8_t vsum016 =
        vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6));
    const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
    const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));

    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23));
    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45));
    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016));
    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016));

    float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
    float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

    vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
    vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
    vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
    vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
    const int16x8_t vacc = vqaddq_s16(
        vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
    uint8x8_t vout = vqmovun_s16(vacc);
    vout = vmax_u8(vout, voutput_min);
    vout = vmin_u8(vout, voutput_max);
#else
    vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
    vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

    vacc_lo = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
    vacc_hi = vsubq_s32(
        vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
    const int16x8_t vacc =
        vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
    uint8x8_t vout = vqmovun_s16(vacc);
#endif

    /* Store exactly n tail bytes with unaligned 4/2/1-byte lane stores,
     * rotating consumed bytes out of the vector with vext after each store. */
    if (n & 4) {
      vst1_lane_u32(
          __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
      output += 4;
      vout = vext_u8(vout, vout, 4);
    }
    if (n & 2) {
      vst1_lane_u16(
          __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
      output += 2;
      vout = vext_u8(vout, vout, 2);
    }
    if (n & 1) {
      vst1_lane_u8(output, vout, 0);
    }
  }
}
211