1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
25
26 #include <cstddef>
27 #include <arm_neon.h>
28
29 namespace arm_conv {
30 namespace winograd {
31 namespace weight_transform {
32
a64_fp16_4x4_3x3(unsigned int n_channels,const __fp16 * inptr,const size_t ld_weight_row,const size_t ld_weight_col,__fp16 * outptr,const size_t matrix_stride)33 void a64_fp16_4x4_3x3(
34 unsigned int n_channels,
35 const __fp16* inptr, // NOTE: Data in HWIO order
36 const size_t ld_weight_row,
37 const size_t ld_weight_col,
38 __fp16* outptr,
39 const size_t matrix_stride
40 )
41 {
42 #ifdef __aarch64__
43 for (; n_channels >= 8; n_channels -= 8)
44 {
45 // Matrices used and computed in this kernel
46 float16x8_t w[3][3], Ww[6][3], V[6][6];
47
48 // Read weights
49 for (int i = 0; i < 3; i++)
50 {
51 for (int j = 0; j < 3; j++)
52 {
53 w[i][j] = vld1q_f16(inptr + i*ld_weight_row + j*ld_weight_col);
54 }
55 }
56
57 // Compute the matrix W w
58 for (int j = 0; j < 3; j++)
59 {
60 // Ww[0][j] = 6*w[0][j];
61 Ww[0][j] = vmulq_n_f16(w[0][j], 6.0);
62
63 // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
64 Ww[1][j] = vmulq_n_f16(vaddq_f16(vaddq_f16(w[0][j], w[1][j]), w[2][j]), -4.0);
65
66 // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
67 Ww[2][j] = vmulq_n_f16(vsubq_f16(vsubq_f16(w[1][j], w[0][j]), w[2][j]), 4.0);
68
69 // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
70 Ww[3][j] = vaddq_f16(vaddq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));
71
72 // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
73 Ww[4][j] = vaddq_f16(vsubq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));
74
75 // Ww[5][j] = 24*w[2][j];
76 Ww[5][j] = vmulq_n_f16(w[2][j], 24.0f);
77 }
78
79 // Compute V = W w WT
80 for (int i = 0; i < 6; i++)
81 {
82 const float recip576 = 1.0f / 576.0f;
83
84 // V[i][0] = 6*Ww[i][0];
85 V[i][0] = vmulq_n_f16(vmulq_n_f16(Ww[i][0], 6.0), recip576);
86
87 // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
88 V[i][1] = vmulq_n_f16(vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
89
90 // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
91 V[i][2] = vmulq_n_f16(vmulq_n_f16(vsubq_f16(vsubq_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
92
93 // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
94 V[i][3] = vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);
95
96 // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
97 V[i][4] = vmulq_n_f16(vaddq_f16(vsubq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);
98
99 // V[i][5] = 24*Ww[i][2];
100 V[i][5] = vmulq_n_f16(vmulq_n_f16(Ww[i][2], 24.0f), recip576);
101 }
102
103 // Store the transformed weights
104 for (int i = 0, m = 0; i < 6; i++)
105 {
106 for (int j = 0; j < 6; j++, m++)
107 {
108 vst1q_f16(outptr + m*matrix_stride, V[i][j]);
109 }
110 }
111 inptr += 8;
112 outptr += 8;
113 }
114 #endif // __aarch64__
115 #ifdef __arm_any__
116 for (; n_channels >= 4; n_channels -= 4)
117 {
118 // Matrices used and computed in this kernel
119 float16x4_t w[3][3], Ww[6][3], V[6][6];
120
121 // Read weights
122 for (int i = 0; i < 3; i++)
123 {
124 for (int j = 0; j < 3; j++)
125 {
126 w[i][j] = vld1_f16(inptr + i*ld_weight_row + j*ld_weight_col);
127 }
128 }
129
130 // Compute the matrix W w
131 for (int j = 0; j < 3; j++)
132 {
133 // Ww[0][j] = 6*w[0][j];
134 Ww[0][j] = vmul_n_f16(w[0][j], 6.0);
135
136 // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
137 Ww[1][j] = vmul_n_f16(vadd_f16(vadd_f16(w[0][j], w[1][j]), w[2][j]), -4.0);
138
139 // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
140 Ww[2][j] = vmul_n_f16(vsub_f16(vsub_f16(w[1][j], w[0][j]), w[2][j]), 4.0);
141
142 // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
143 Ww[3][j] = vadd_f16(vadd_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));
144
145 // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
146 Ww[4][j] = vadd_f16(vsub_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));
147
148 // Ww[5][j] = 24*w[2][j];
149 Ww[5][j] = vmul_n_f16(w[2][j], 24.0f);
150 }
151
152 // Compute V = W w WT
153 for (int i = 0; i < 6; i++)
154 {
155 const float recip576 = 1.0f / 576.0f;
156
157 // V[i][0] = 6*Ww[i][0];
158 V[i][0] = vmul_n_f16(vmul_n_f16(Ww[i][0], 6.0), recip576);
159
160 // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
161 V[i][1] = vmul_n_f16(vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
162
163 // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
164 V[i][2] = vmul_n_f16(vmul_n_f16(vsub_f16(vsub_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
165
166 // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
167 V[i][3] = vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);
168
169 // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
170 V[i][4] = vmul_n_f16(vadd_f16(vsub_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);
171
172 // V[i][5] = 24*Ww[i][2];
173 V[i][5] = vmul_n_f16(vmul_n_f16(Ww[i][2], 24.0f), recip576);
174 }
175
176 // Store the transformed weights
177 for (int i = 0, m = 0; i < 6; i++)
178 {
179 for (int j = 0; j < 6; j++, m++)
180 {
181 vst1_f16(outptr + m*matrix_stride, V[i][j]);
182 }
183 }
184 inptr += 4;
185 outptr += 4;
186 }
187 #endif // __arm_any__
188 for (; n_channels; n_channels--)
189 {
190 // Matrices used and computed in this kernel
191 __fp16 w[3][3], Ww[6][3], V[6][6];
192
193 // Read weights
194 for (int i = 0; i < 3; i++)
195 {
196 for (int j = 0; j < 3; j++)
197 {
198 w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col);
199 }
200 }
201
202 // Compute the matrix W w
203 for (int j = 0; j < 3; j++)
204 {
205 Ww[0][j] = 6*w[0][j];
206 Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
207 Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
208 Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
209 Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
210 Ww[5][j] = 24*w[2][j];
211 }
212
213 // Compute V = W w WT
214 for (int i = 0; i < 6; i++)
215 {
216 V[i][0] = ( 6*Ww[i][0]) / 576.0;
217 V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
218 V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
219 V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0;
220 V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0;
221 V[i][5] = (24*Ww[i][2]) / 576.0;
222 }
223
224 // Store the transformed weights
225 for (int i = 0, m = 0; i < 6; i++)
226 {
227 for (int j = 0; j < 6; j++, m++)
228 {
229 *(outptr + m*matrix_stride) = V[i][j];
230 }
231 }
232
233 inptr++;
234 outptr++;
235 }
236 }
237
238 } // namespace weight_transform
239 } // namespace winograd
240 } // namespace arm_conv
241
242 #endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
243