/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>

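/*
 * Row-sum microkernel: computes the sum of each of up to four rows of an
 * m x k uint8 matrix `a` (rows spaced `stride` bytes apart), multiplies
 * each sum by `multiplier`, and writes the m int32 results to `a_sum`.
 */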
void pytorch_q8sumrows_ukernel_4x__neon(
    const uint8_t* restrict a,
    size_t m,
    size_t k,
    size_t stride,
    const int32_t multiplier,
    int32_t* restrict a_sum) {
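  /*
   * Set up one pointer per row. When m < 4, the out-of-range row pointers
   * alias the last valid row, so every load below stays in bounds; the
   * redundant sums they produce are simply not stored at the end.
   */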
  const uint8_t* a0 = a;
  const uint8_t* a1 = a0;
  if (m >= 2) {
    a1 += stride;
  }
  const uint8_t* a2 = a1;
  if (m > 2) {
    a2 += stride;
  }
  const uint8_t* a3 = a2;
  if (m == 4) {
    a3 += stride;
  }

  uint32x4_t vacc0x0123 = vmovq_n_u32(0); // row 0
  uint32x4_t vacc1x0123 = vmovq_n_u32(0); // row 1
  uint32x4_t vacc2x0123 = vmovq_n_u32(0); // row 2
  uint32x4_t vacc3x0123 = vmovq_n_u32(0); // row 3
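  /*
   * Main loop: consume 16 bytes per row per iteration. vaddl_u8 widens and
   * adds the two 8-byte halves into eight uint16 partial sums (255 + 255
   * still fits in 16 bits), and vpadalq_u16 pairwise-adds those into the
   * four uint32 accumulator lanes, so no intermediate sum can overflow.
   */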
  for (; k >= 16; k -= 16) {
    // row 0
    const uint8x16_t va0x0_15 = vld1q_u8(a0);
    a0 += 16;
    vacc0x0123 = vpadalq_u16(
        vacc0x0123, vaddl_u8(vget_low_u8(va0x0_15), vget_high_u8(va0x0_15)));

    // row 1
    const uint8x16_t va1x0_15 = vld1q_u8(a1);
    a1 += 16;
    vacc1x0123 = vpadalq_u16(
        vacc1x0123, vaddl_u8(vget_low_u8(va1x0_15), vget_high_u8(va1x0_15)));

    // row 2
    const uint8x16_t va2x0_15 = vld1q_u8(a2);
    a2 += 16;
    vacc2x0123 = vpadalq_u16(
        vacc2x0123, vaddl_u8(vget_low_u8(va2x0_15), vget_high_u8(va2x0_15)));

    // row 3
    const uint8x16_t va3x0_15 = vld1q_u8(a3);
    a3 += 16;
    vacc3x0123 = vpadalq_u16(
        vacc3x0123, vaddl_u8(vget_low_u8(va3x0_15), vget_high_u8(va3x0_15)));
  }

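  /*
   * 8-byte tail: vpaddl_u8 pairwise-widens eight uint8 values into four
   * uint16 sums, which vaddw_u16 then adds into the uint32 accumulators.
   */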
  if (k >= 8) {
    vacc0x0123 = vaddw_u16(vacc0x0123, vpaddl_u8(vld1_u8(a0)));
    a0 += 8;
    vacc1x0123 = vaddw_u16(vacc1x0123, vpaddl_u8(vld1_u8(a1)));
    a1 += 8;
    vacc2x0123 = vaddw_u16(vacc2x0123, vpaddl_u8(vld1_u8(a2)));
    a2 += 8;
    vacc3x0123 = vaddw_u16(vacc3x0123, vpaddl_u8(vld1_u8(a3)));
    a3 += 8;
    k -= 8;
  }

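  /*
   * 4-byte tail: each row's next four bytes are loaded as a single uint32
   * (__builtin_assume_aligned(..., 1) marks the pointer as byte-aligned, so
   * the compiler does not assume the natural alignment of uint32_t),
   * zero-extended to uint16 by vmovl_u8, and added into the accumulators.
   */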
  if (k >= 4) {
    vacc0x0123 = vaddw_u16(
        vacc0x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a0, 1))))));
    a0 += 4;
    vacc1x0123 = vaddw_u16(
        vacc1x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a1, 1))))));
    a1 += 4;
    vacc2x0123 = vaddw_u16(
        vacc2x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a2, 1))))));
    a2 += 4;
    vacc3x0123 = vaddw_u16(
        vacc3x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a3, 1))))));
    a3 += 4;
    k -= 4;
  }

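  /*
   * Reduce each row's four uint32 partial sums to a single value with
   * pairwise adds, then pack the four row sums into one vector as
   * [sum0, sum1, sum2, sum3] so the remaining tails can update all four
   * rows through a single accumulator.
   */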
  const uint32x2_t vsum0x01 =
      vpadd_u32(vget_low_u32(vacc0x0123), vget_high_u32(vacc0x0123));
  const uint32x2_t vsum1x01 =
      vpadd_u32(vget_low_u32(vacc1x0123), vget_high_u32(vacc1x0123));
  const uint32x2_t vsum2x01 =
      vpadd_u32(vget_low_u32(vacc2x0123), vget_high_u32(vacc2x0123));
  const uint32x2_t vsum3x01 =
      vpadd_u32(vget_low_u32(vacc3x0123), vget_high_u32(vacc3x0123));
  uint32x4_t vacc0123 = vcombine_u32(
      vpadd_u32(vsum0x01, vsum1x01), vpadd_u32(vsum2x01, vsum3x01));

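  /*
   * 2-byte tail: each row's remaining pair is loaded as a uint16 duplicated
   * across all lanes; three vext_u8 splices then gather the four pairs into
   * one 8-byte vector ordered [row0, row1, row2, row3], so vpaddl_u8
   * produces each row's pair sum in the uint16 lane matching its
   * accumulator lane.
   */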
  if (k >= 2) {
    const uint8x8_t va0x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x8_t va0x01_1x010101 = vext_u8(va0x01010101, va1x01010101, 2);
    const uint8x8_t va2x01_3x010101 = vext_u8(va2x01010101, va3x01010101, 6);
    const uint8x8_t va0123x01 = vext_u8(va0x01_1x010101, va2x01_3x010101, 4);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(va0123x01));
    k -= 2;
  }

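  /*
   * Final byte: place each row's last element in byte lanes 0, 2, 4, and 6
   * of a zeroed vector, so pairing each byte with a zero via vpaddl_u8
   * yields the per-row value in the matching uint16 lane.
   */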
  if (k > 0) {
    uint8x8_t vax0x1x2x3 = vmov_n_u8(0);
    vax0x1x2x3 = vld1_lane_u8(a0, vax0x1x2x3, 0);
    vax0x1x2x3 = vld1_lane_u8(a1, vax0x1x2x3, 2);
    vax0x1x2x3 = vld1_lane_u8(a2, vax0x1x2x3, 4);
    vax0x1x2x3 = vld1_lane_u8(a3, vax0x1x2x3, 6);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(vax0x1x2x3));
  }

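  /*
   * Scale the four row sums by `multiplier` (the uint32 sums are
   * reinterpreted as int32; the low 32 bits of the product are the same
   * either way) and store only the m results the caller asked for: all
   * four at once, or a pair followed by an optional single lane.
   */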
  int32x4_t vsum0123 = vmulq_n_s32(vreinterpretq_s32_u32(vacc0123), multiplier);
  if (m == 4) {
    vst1q_s32(a_sum, vsum0123);
  } else {
    if (m >= 2) {
      vst1_s32(a_sum, vget_low_s32(vsum0123));
      a_sum += 2;
      vsum0123 = vextq_s32(vsum0123, vsum0123, 2);
      m -= 2;
    }
    if (m != 0) {
      vst1q_lane_s32(a_sum, vsum0123, 0);
    }
  }
}