/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9 #include <assert.h>
10
11 #include <arm_neon.h>
12
13 #include <qnnpack/q8gavgpool.h>
14
/*
 * q8gavgpool micro-kernel (NEON) for rows of fewer than 8 elements.
 *
 * Adds up m rows of n uint8 elements each, lane by lane, on top of
 * neon.bias, multiplies the 32-bit sums by neon.scale, converts the result
 * back to uint8 and writes exactly n bytes to output.
 *
 *   m                   - number of rows to accumulate (asserted >= 1)
 *   n                   - number of elements per row (asserted < 8)
 *   input               - pointer to the first row
 *   input_stride        - byte offset from one row to the next
 *   zero                - never referenced by this kernel
 *   output              - destination for the n result bytes
 *   quantization_params - supplies bias, scale, and the requantization
 *                         constants read from its .neon member
 */
void pytorch_q8gavgpool_ukernel_up8xm__neon(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    uint8_t* output,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(m >= 1);
  assert(n < 8);

  /* Lanes 0-3 and 4-7 are accumulated separately; both start at the bias. */
  int32x4_t vsum_lo = vld1q_dup_s32(&quantization_params->neon.bias);
  int32x4_t vsum_hi = vsum_lo;

  /*
   * While 8 or more rows remain, load a whole 8-byte vector per row. Lanes
   * at index >= n pick up the bytes that follow the row in memory; the sums
   * in those lanes are never stored (only n bytes are written at the end).
   */
  for (; m >= 8; m--) {
    const int16x8_t vrow = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(input)));
    input += input_stride;
    vsum_lo = vaddw_s16(vsum_lo, vget_low_s16(vrow));
    vsum_hi = vaddw_s16(vsum_hi, vget_high_s16(vrow));
  }

  /*
   * For the remaining rows, read exactly n bytes. The row is assembled back
   * to front: jump to its end, then peel off 1, 2 and 4 byte pieces according
   * to the bits of n, rotating what was already loaded towards the upper
   * lanes so that each new piece lands in lane 0.
   */
  for (; m != 0; m--) {
    input += n;
    uint8x8_t vrow_u8 = vmov_n_u8(0);
    if (n & 1) {
      input -= 1;
      vrow_u8 = vld1_lane_u8(input, vrow_u8, 0);
    }
    if (n & 2) {
      /* Move the bytes loaded so far up by two lanes. */
      vrow_u8 = vext_u8(vrow_u8, vrow_u8, 6);
      input -= 2;
      /* The pointer is only byte-aligned; tell the compiler so. */
      const uint16x4_t vpair = vld1_lane_u16(
          __builtin_assume_aligned(input, 1), vreinterpret_u16_u8(vrow_u8), 0);
      vrow_u8 = vreinterpret_u8_u16(vpair);
    }
    if (n & 4) {
      /* Move the bytes loaded so far up by four lanes. */
      vrow_u8 = vext_u8(vrow_u8, vrow_u8, 4);
      input -= 4;
      const uint32x2_t vquad = vld1_lane_u32(
          __builtin_assume_aligned(input, 1), vreinterpret_u32_u8(vrow_u8), 0);
      vrow_u8 = vreinterpret_u8_u32(vquad);
    }
    /* For n < 8 the pointer is back at the row start here. */
    input += input_stride;

    const int16x8_t vrow = vreinterpretq_s16_u16(vmovl_u8(vrow_u8));
    vsum_lo = vaddw_s16(vsum_lo, vget_low_s16(vrow));
    vsum_hi = vaddw_s16(vsum_hi, vget_high_s16(vrow));
  }

  const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale);
  const int16x8_t vzero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);

  /* Scale the integer sums in single precision. */
  const float32x4_t vscaled_lo = vmulq_f32(vcvtq_f32_s32(vsum_lo), vscale);
  const float32x4_t vscaled_hi = vmulq_f32(vcvtq_f32_s32(vsum_hi), vscale);

#if defined(__aarch64__)
  const uint8x8_t vout_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t vout_max =
      vld1_dup_u8(&quantization_params->neon.output_max);

  /*
   * Convert to int32 with rounding to nearest, narrow to int16 with
   * saturation, add the output zero point (saturating), narrow to uint8 with
   * saturation, then clamp to [output_min, output_max].
   */
  const int32x4_t vround_lo = vcvtnq_s32_f32(vscaled_lo);
  const int32x4_t vround_hi = vcvtnq_s32_f32(vscaled_hi);
  const int16x8_t vnarrow = vqmovn_high_s32(vqmovn_s32(vround_lo), vround_hi);
  const int16x8_t vshifted = vqaddq_s16(vnarrow, vzero_point);
  uint8x8_t vout = vqmovun_s16(vshifted);
  vout = vmax_u8(vout, vout_min);
  vout = vmin_u8(vout, vout_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);

  /* Clamp in the float domain first. */
  const float32x4_t vclamped_lo =
      vminq_f32(vmaxq_f32(vscaled_lo, vfmin), vfmax);
  const float32x4_t vclamped_hi =
      vminq_f32(vmaxq_f32(vscaled_hi, vfmin), vfmax);

  /*
   * Float -> int conversion without a rounding-convert instruction: add
   * vfmagic, reinterpret the bits as int32 and subtract vimagic.
   * NOTE(review): vimagic presumably also folds in the output zero point,
   * since this path never adds it explicitly - confirm where the params are
   * initialized.
   */
  const int32x4_t vint_lo = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vclamped_lo, vfmagic)), vimagic);
  const int32x4_t vint_hi = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vclamped_hi, vfmagic)), vimagic);
  const int16x8_t vnarrow =
      vcombine_s16(vqmovn_s32(vint_lo), vqmovn_s32(vint_hi));
  uint8x8_t vout = vqmovun_s16(vnarrow);
#endif

  /*
   * Write n bytes front to back in 4/2/1 byte pieces, rotating the vector
   * down after each store so the next piece is in lane 0.
   */
  if (n & 4) {
    vst1_lane_u32(
        __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
    output += 4;
    vout = vext_u8(vout, vout, 4);
  }
  if (n & 2) {
    vst1_lane_u16(
        __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
    output += 2;
    vout = vext_u8(vout, vout, 2);
  }
  if (n & 1) {
    vst1_lane_u8(output, vout, 0);
  }
}
123