xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <assert.h>
10 
11 #include <arm_neon.h>
12 
13 #include <qnnpack/q8gavgpool.h>
14 
pytorch_q8gavgpool_ukernel_up8xm__neon(size_t m,size_t n,const uint8_t * input,size_t input_stride,const uint8_t * zero,uint8_t * output,const union pytorch_qnnp_avgpool_quantization_params quantization_params[restrict static1])15 void pytorch_q8gavgpool_ukernel_up8xm__neon(
16     size_t m,
17     size_t n,
18     const uint8_t* input,
19     size_t input_stride,
20     const uint8_t* zero,
21     uint8_t* output,
22     const union pytorch_qnnp_avgpool_quantization_params
23         quantization_params[restrict static 1]) {
24   assert(m >= 1);
25   assert(n < 8);
26 
27   const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
28   int32x4_t vacc_lo = vbias;
29   int32x4_t vacc_hi = vbias;
30   while (m >= 8) {
31     const uint8x8_t vinput = vld1_u8(input);
32     input += input_stride;
33     const int16x8_t vxinput = vreinterpretq_s16_u16(vmovl_u8(vinput));
34     vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vxinput));
35     vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vxinput));
36 
37     m--;
38   }
39   while (m-- != 0) {
40     input += n;
41     uint8x8_t vinput = vmov_n_u8(0);
42     if (n & 1) {
43       input -= 1;
44       vinput = vld1_lane_u8(input, vinput, 0);
45     }
46     if (n & 2) {
47       vinput = vext_u8(vinput, vinput, 6);
48       input -= 2;
49       vinput = vreinterpret_u8_u16(vld1_lane_u16(
50           __builtin_assume_aligned(input, 1), vreinterpret_u16_u8(vinput), 0));
51     }
52     if (n & 4) {
53       vinput = vext_u8(vinput, vinput, 4);
54       input -= 4;
55       vinput = vreinterpret_u8_u32(vld1_lane_u32(
56           __builtin_assume_aligned(input, 1), vreinterpret_u32_u8(vinput), 0));
57     }
58     input += input_stride;
59 
60     const int16x8_t vxinput = vreinterpretq_s16_u16(vmovl_u8(vinput));
61     vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vxinput));
62     vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vxinput));
63   }
64 
65   const float32x4_t vscale =
66       vdupq_n_f32(quantization_params->neon.scale);
67   const int16x8_t voutput_zero_point =
68       vld1q_dup_s16(&quantization_params->neon.output_zero_point);
69 
70   float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
71   float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);
72 
73   vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
74   vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);
75 
76 #if defined(__aarch64__)
77   const uint8x8_t voutput_min =
78       vld1_dup_u8(&quantization_params->neon.output_min);
79   const uint8x8_t voutput_max =
80       vld1_dup_u8(&quantization_params->neon.output_max);
81 
82   vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
83   vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
84   const int16x8_t vacc = vqaddq_s16(
85       vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
86   uint8x8_t vout = vqmovun_s16(vacc);
87   vout = vmax_u8(vout, voutput_min);
88   vout = vmin_u8(vout, voutput_max);
89 #else
90   const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
91   const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
92   const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
93   const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
94 
95   vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
96   vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
97 
98   vacc_lo = vsubq_s32(
99       vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
100   vacc_hi = vsubq_s32(
101       vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
102   const int16x8_t vacc =
103       vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
104   uint8x8_t vout = vqmovun_s16(vacc);
105 #endif
106 
107   if (n & 4) {
108     vst1_lane_u32(
109         __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
110     output += 4;
111     vout = vext_u8(vout, vout, 4);
112   }
113   if (n & 2) {
114     vst1_lane_u16(
115         __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
116     output += 2;
117     vout = vext_u8(vout, vout, 2);
118   }
119   if (n & 1) {
120     vst1_lane_u8(output, vout, 0);
121   }
122 }
123