1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include <cstddef>
26 #include <arm_neon.h>
27 
namespace arm_conv {
namespace winograd {
namespace weight_transform {

// Winograd F(2x2, 5x5) weight transform.
//
// Each 5x5 kernel w is mapped to the 6x6 matrix V = G w G^T, where
//
//         [  1/4      0     0      0     0 ]
//         [ -1/6   -1/6  -1/6   -1/6  -1/6 ]
//   G  =  [ -1/6    1/6  -1/6    1/6  -1/6 ]
//         [ 1/24   1/12   1/6    1/3   2/3 ]
//         [ 1/24  -1/12   1/6   -1/3   2/3 ]
//         [    0      0     0      0     1 ]
//
// The fp32_2x2_5x5_apply_G helpers apply G to one 5-vector (x0..x4) and write
// the six results to y[0..5]. There is one overload per lane width. They
// evaluate the same formulas, except that the vector overloads multiply by
// reciprocal constants while the scalar overload divides. A channel handled
// by the scalar tail may therefore differ from a vectorised channel in the
// last bit.

#ifdef __aarch64__
// Apply G to four channels at once (one float32x4_t per tap).
static inline void fp32_2x2_5x5_apply_G(
  const float32x4_t x0, const float32x4_t x1, const float32x4_t x2,
  const float32x4_t x3, const float32x4_t x4, float32x4_t *y
)
{
  // y[0] = x0/4
  y[0] = vmulq_n_f32(x0, 1.0f/4.0f);

  // y[1] = -(x0 + x1 + x2 + x3 + x4)/6
  y[1] = vmulq_n_f32(
    vaddq_f32(
      vaddq_f32(
        vaddq_f32(x1, x0),
        vaddq_f32(x3, x2)
      ),
      x4
    ),
    -1.0f/6.0f
  );

  // y[2] = (-x0 + x1 - x2 + x3 - x4)/6, grouped as ((x1 - x0) + (x3 - x2) - x4)/6
  y[2] = vmulq_n_f32(
    vsubq_f32(
      vaddq_f32(
        vsubq_f32(x1, x0),
        vsubq_f32(x3, x2)
      ),
      x4
    ),
    1.0f/6.0f
  );

  // y[3] = (x0/8 + x1/4 + x2/2 + x3 + 2*x4)/3
  y[3] = vmulq_n_f32(
    vmlaq_n_f32(
      vaddq_f32(
        vaddq_f32(vmulq_n_f32(x0, 1.0f/8.0f), vmulq_n_f32(x1, 1.0f/4.0f)),
        vaddq_f32(vmulq_n_f32(x2, 1.0f/2.0f), x3)
      ),
      x4, 2.0f
    ),
    1.0f/3.0f
  );

  // y[4] = (x0/8 - x1/4 + x2/2 - x3 + 2*x4)/3
  y[4] = vmulq_n_f32(
    vmlaq_n_f32(
      vaddq_f32(
        vsubq_f32(vmulq_n_f32(x0, 1.0f/8.0f), vmulq_n_f32(x1, 1.0f/4.0f)),
        vsubq_f32(vmulq_n_f32(x2, 1.0f/2.0f), x3)
      ),
      x4, 2.0f
    ),
    1.0f/3.0f
  );

  // y[5] = x4
  y[5] = x4;
}
#endif // __aarch64__

// Apply G to two channels at once (one float32x2_t per tap).
static inline void fp32_2x2_5x5_apply_G(
  const float32x2_t x0, const float32x2_t x1, const float32x2_t x2,
  const float32x2_t x3, const float32x2_t x4, float32x2_t *y
)
{
  // y[0] = x0/4
  y[0] = vmul_n_f32(x0, 1.0f/4.0f);

  // y[1] = -(x0 + x1 + x2 + x3 + x4)/6
  y[1] = vmul_n_f32(
    vadd_f32(
      vadd_f32(
        vadd_f32(x1, x0),
        vadd_f32(x3, x2)
      ),
      x4
    ),
    -1.0f/6.0f
  );

  // y[2] = (-x0 + x1 - x2 + x3 - x4)/6, grouped as ((x1 - x0) + (x3 - x2) - x4)/6
  y[2] = vmul_n_f32(
    vsub_f32(
      vadd_f32(
        vsub_f32(x1, x0),
        vsub_f32(x3, x2)
      ),
      x4
    ),
    1.0f/6.0f
  );

  // y[3] = (x0/8 + x1/4 + x2/2 + x3 + 2*x4)/3
  y[3] = vmul_n_f32(
    vmla_n_f32(
      vadd_f32(
        vadd_f32(vmul_n_f32(x0, 1.0f/8.0f), vmul_n_f32(x1, 1.0f/4.0f)),
        vadd_f32(vmul_n_f32(x2, 1.0f/2.0f), x3)
      ),
      x4, 2.0f
    ),
    1.0f/3.0f
  );

  // y[4] = (x0/8 - x1/4 + x2/2 - x3 + 2*x4)/3
  y[4] = vmul_n_f32(
    vmla_n_f32(
      vadd_f32(
        vsub_f32(vmul_n_f32(x0, 1.0f/8.0f), vmul_n_f32(x1, 1.0f/4.0f)),
        vsub_f32(vmul_n_f32(x2, 1.0f/2.0f), x3)
      ),
      x4, 2.0f
    ),
    1.0f/3.0f
  );

  // y[5] = x4
  y[5] = x4;
}

// Apply G to a single channel.
static inline void fp32_2x2_5x5_apply_G(
  const float x0, const float x1, const float x2,
  const float x3, const float x4, float *y
)
{
  y[0] = x0/4.0f;
  y[1] = -( x0 + x1 + x2 + x3 + x4)/6.0f;
  y[2] = +(-x0 + x1 - x2 + x3 - x4)/6.0f;
  y[3] = (x0/8.0f + x1/4.0f + x2/2.0f + x3 + 2*x4)/3.0f;
  y[4] = (x0/8.0f - x1/4.0f + x2/2.0f - x3 + 2*x4)/3.0f;
  y[5] = x4;
}

/*
 * Transform n_channels 5x5 kernels into the Winograd domain (V = G w G^T).
 *
 * Input:  tap (i, j) of channel c is read from
 *           inptr[i*ld_weight_row + j*ld_weight_col + c]
 *         (channels are contiguous for each tap).
 * Output: element (i, j) of V for channel c is written to
 *           outptr[(6*i + j)*matrix_stride + c]
 *         (36 matrices, each holding one float per channel).
 *
 * Channels are processed 4 at a time (AArch64 only), then 2 at a time, then
 * one at a time for the remainder.
 */
void arm_fp32_2x2_5x5(
  unsigned int n_channels,
  const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col,
  float *outptr, const size_t matrix_stride
)
{
#ifdef __aarch64__
  for (; n_channels >= 4; n_channels -= 4)
  {
    // Gw_t holds G w transposed: Gw_t[j][i] == (G w)[i][j].
    float32x4_t w[5][5], Gw_t[5][6], V[6][6];

    // Load the 5x5 taps for four consecutive channels.
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // G w: transform each column of w.
    for (int j = 0; j < 5; j++)
    {
      fp32_2x2_5x5_apply_G(w[0][j], w[1][j], w[2][j], w[3][j], w[4][j], Gw_t[j]);
    }

    // V = (G w) G^T: transform each row of G w.
    for (int i = 0; i < 6; i++)
    {
      fp32_2x2_5x5_apply_G(Gw_t[0][i], Gw_t[1][i], Gw_t[2][i], Gw_t[3][i], Gw_t[4][i], V[i]);
    }

    // Store V: element (i, j) goes to matrix 6*i + j.
    for (int i = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++)
      {
        vst1q_f32(outptr + (i*6 + j)*matrix_stride, V[i][j]);
      }
    }

    inptr += 4;
    outptr += 4;
  }
#endif // __aarch64__
  for (; n_channels >= 2; n_channels -= 2)
  {
    // Gw_t holds G w transposed: Gw_t[j][i] == (G w)[i][j].
    float32x2_t w[5][5], Gw_t[5][6], V[6][6];

    // Load the 5x5 taps for two consecutive channels.
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // G w: transform each column of w.
    for (int j = 0; j < 5; j++)
    {
      fp32_2x2_5x5_apply_G(w[0][j], w[1][j], w[2][j], w[3][j], w[4][j], Gw_t[j]);
    }

    // V = (G w) G^T: transform each row of G w.
    for (int i = 0; i < 6; i++)
    {
      fp32_2x2_5x5_apply_G(Gw_t[0][i], Gw_t[1][i], Gw_t[2][i], Gw_t[3][i], Gw_t[4][i], V[i]);
    }

    // Store V: element (i, j) goes to matrix 6*i + j.
    for (int i = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++)
      {
        vst1_f32(outptr + (i*6 + j)*matrix_stride, V[i][j]);
      }
    }

    inptr += 2;
    outptr += 2;
  }
  for (; n_channels; n_channels--)
  {
    // Gw_t holds G w transposed: Gw_t[j][i] == (G w)[i][j].
    float w[5][5], Gw_t[5][6], V[6][6];

    // Load the 5x5 taps for one channel.
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = inptr[i*ld_weight_row + j*ld_weight_col];
      }
    }

    // G w: transform each column of w.
    for (int j = 0; j < 5; j++)
    {
      fp32_2x2_5x5_apply_G(w[0][j], w[1][j], w[2][j], w[3][j], w[4][j], Gw_t[j]);
    }

    // V = (G w) G^T: transform each row of G w.
    for (int i = 0; i < 6; i++)
    {
      fp32_2x2_5x5_apply_G(Gw_t[0][i], Gw_t[1][i], Gw_t[2][i], Gw_t[3][i], Gw_t[4][i], V[i]);
    }

    // Store V: element (i, j) goes to matrix 6*i + j.
    for (int i = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++)
      {
        outptr[(i*6 + j)*matrix_stride] = V[i][j];
      }
    }

    inptr++;
    outptr++;
  }
}

}  // namespace weight_transform
}  // namespace winograd
}  // namespace arm_conv
382