1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include <cstddef>
26 #include <arm_neon.h>
27 
28 namespace arm_conv {
29 namespace winograd {
30 namespace weight_transform {
31 
// Helpers: apply the 6x3 matrix W used by this transform to one 3-vector
// (a, b, c), writing the 6 results to out[0..5]:
//
//       [  6   0   0 ]
//       [ -4  -4  -4 ]
//   W = [ -4   4  -4 ]
//       [  1   2   4 ]
//       [  1  -2   4 ]
//       [  0   0  24 ]
//
// The same map serves both passes of the transform:
//   - column pass: (W w) is W applied to each column of w;
//   - row pass: row i of (W w) W^T is W applied to row i of (W w).
// One helper exists per lane width (4 x f32, 2 x f32, scalar f32).
#ifdef __aarch64__
static inline void transform_3_to_6_f32x4(
  const float32x4_t a, const float32x4_t b, const float32x4_t c,
  float32x4_t out[6]
)
{
  out[0] = vmulq_n_f32(a, 6.0f);                                // 6a
  out[1] = vmulq_n_f32(vaddq_f32(vaddq_f32(a, b), c), -4.0f);   // -4(a + b + c)
  out[2] = vmulq_n_f32(vsubq_f32(vsubq_f32(b, a), c), 4.0f);    // 4(b - a - c)
  out[3] = vmlaq_n_f32(vmlaq_n_f32(a, b, 2.0f), c, 4.0f);       // a + 2b + 4c
  out[4] = vmlaq_n_f32(vmlsq_n_f32(a, b, 2.0f), c, 4.0f);       // a - 2b + 4c
  out[5] = vmulq_n_f32(c, 24.0f);                               // 24c
}
#endif // __aarch64__

static inline void transform_3_to_6_f32x2(
  const float32x2_t a, const float32x2_t b, const float32x2_t c,
  float32x2_t out[6]
)
{
  out[0] = vmul_n_f32(a, 6.0f);
  out[1] = vmul_n_f32(vadd_f32(vadd_f32(a, b), c), -4.0f);
  out[2] = vmul_n_f32(vsub_f32(vsub_f32(b, a), c), 4.0f);
  out[3] = vmla_n_f32(vmla_n_f32(a, b, 2.0f), c, 4.0f);
  out[4] = vmla_n_f32(vmls_n_f32(a, b, 2.0f), c, 4.0f);
  out[5] = vmul_n_f32(c, 24.0f);
}

static inline void transform_3_to_6_f32(
  const float a, const float b, const float c,
  float out[6]
)
{
  out[0] =  6*a;
  out[1] = -4*a + -4*b + -4*c;
  out[2] = -4*a +  4*b + -4*c;
  out[3] =  1*a +  2*b +  4*c;
  out[4] =  1*a + -2*b +  4*c;
  out[5] = 24*c;
}

/* Transform 3x3 weights into the 36 values V = W w W^T / 576, where W is
 * the 6x3 matrix documented above, for each of `n_channels` channels.
 *
 * Layout (as used by the address arithmetic below):
 *   - input element (i, j) of a channel is at
 *     inptr[i*ld_weight_row + j*ld_weight_col];
 *   - consecutive channels are adjacent floats (the input and output
 *     pointers advance by one float per channel);
 *   - output element (i, j) of a channel is written to
 *     outptr[(6*i + j)*matrix_stride], so matrix_stride separates the 36
 *     transformed matrices.
 *
 * Channels are processed 4 at a time (AArch64 only), then 2 at a time,
 * then one at a time. For each channel group, every weight is loaded
 * before any output is stored.
 *
 * NOTE(review): the vector paths scale by the float constant 1/576, while
 * the scalar tail divides by 576.0 in double precision. A channel's result
 * may therefore differ in the last bit depending on which path handles it.
 */
void arm_fp32_4x4_3x3(
  unsigned int n_channels,
  const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col,
  float *outptr, const size_t matrix_stride
)
{
  // Final 1/576 scale applied by the vector paths (loop-invariant).
  const float recip576 = 1.0f / 576.0f;

#ifdef __aarch64__
  while (n_channels >= 4)
  {
    // Column pass: cols[j][r] holds element (r, j) of W w.
    float32x4_t cols[3][6];
    for (int j = 0; j < 3; j++)
    {
      const float *const src = inptr + j*ld_weight_col;
      transform_3_to_6_f32x4(vld1q_f32(src),
                             vld1q_f32(src + ld_weight_row),
                             vld1q_f32(src + 2*ld_weight_row),
                             cols[j]);
    }

    // Row pass: row i of (W w) W^T, scaled and stored element by element.
    for (int i = 0; i < 6; i++)
    {
      float32x4_t row[6];
      transform_3_to_6_f32x4(cols[0][i], cols[1][i], cols[2][i], row);
      for (int j = 0; j < 6; j++)
      {
        vst1q_f32(outptr + (6*i + j)*matrix_stride, vmulq_n_f32(row[j], recip576));
      }
    }

    inptr += 4;
    outptr += 4;
    n_channels -= 4;
  }
#endif // __aarch64__

  while (n_channels >= 2)
  {
    float32x2_t cols[3][6];
    for (int j = 0; j < 3; j++)
    {
      const float *const src = inptr + j*ld_weight_col;
      transform_3_to_6_f32x2(vld1_f32(src),
                             vld1_f32(src + ld_weight_row),
                             vld1_f32(src + 2*ld_weight_row),
                             cols[j]);
    }

    for (int i = 0; i < 6; i++)
    {
      float32x2_t row[6];
      transform_3_to_6_f32x2(cols[0][i], cols[1][i], cols[2][i], row);
      for (int j = 0; j < 6; j++)
      {
        vst1_f32(outptr + (6*i + j)*matrix_stride, vmul_n_f32(row[j], recip576));
      }
    }

    inptr += 2;
    outptr += 2;
    n_channels -= 2;
  }

  while (n_channels > 0)
  {
    float cols[3][6];
    for (int j = 0; j < 3; j++)
    {
      const float *const src = inptr + j*ld_weight_col;
      transform_3_to_6_f32(src[0], src[ld_weight_row], src[2*ld_weight_row], cols[j]);
    }

    for (int i = 0; i < 6; i++)
    {
      float row[6];
      transform_3_to_6_f32(cols[0][i], cols[1][i], cols[2][i], row);
      for (int j = 0; j < 6; j++)
      {
        // Division is carried out in double, then narrowed to float on store.
        outptr[(6*i + j)*matrix_stride] = static_cast<float>(row[j] / 576.0);
      }
    }

    inptr++;
    outptr++;
    n_channels--;
  }
}
233 
234 }  // namespace weight_transform
235 }  // namespace winograd
236 }  // namespace arm_conv
237