1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
25 
26 #include <cstddef>
27 #include <arm_neon.h>
28 
29 namespace arm_conv {
30 namespace winograd {
31 namespace weight_transform {
32 
// Applies the integer-scaled 6x3 transform matrix used by this kernel,
//
//        [  6   0   0 ]
//        [ -4  -4  -4 ]
//   G' = [ -4   4  -4 ]
//        [  1   2   4 ]
//        [  1  -2   4 ]
//        [  0   0  24 ]
//
// to the 3-vector (x0, x1, x2), writing the six results to out[0..5].
// Every lane is an independent channel. The kernel's effective transform is
// (G'/24) w (G'/24)^T, which is why the result is finally scaled by
// 1/576 = 1/24^2.
static inline void a64_fp16_4x4_3x3_apply_g(
  const float32x4_t x0, const float32x4_t x1, const float32x4_t x2,
  float32x4_t out[6]
)
{
  out[0] = vmulq_n_f32(x0, 6.0f);
  out[1] = vmulq_n_f32(vaddq_f32(vaddq_f32(x0, x1), x2), -4.0f);
  out[2] = vmulq_n_f32(vsubq_f32(vsubq_f32(x1, x0), x2), 4.0f);
  out[3] = vaddq_f32(vaddq_f32(x0, vmulq_n_f32(x1, 2.0f)), vmulq_n_f32(x2, 4.0f));
  out[4] = vaddq_f32(vsubq_f32(x0, vmulq_n_f32(x1, 2.0f)), vmulq_n_f32(x2, 4.0f));
  out[5] = vmulq_n_f32(x2, 24.0f);
}

// Computes V = (G' w G'^T) / 576 for four channels at once, entirely in fp32.
//
// Working in fp32 matters: before the final scale, intermediate values reach
// 576 * |w|, which overflows fp16 (max 65504) once |w| exceeds ~113. It also
// means each output is rounded to fp16 only once, by the caller.
static inline void a64_fp16_4x4_3x3_transform_f32x4(
  float32x4_t w[3][3],  // input kernel, w[row][col]
  float32x4_t V[6][6]   // output, V[row][col]
)
{
  // Ww = G' w (6x3): transform each kernel column.
  float32x4_t Ww[6][3];
  for (int j = 0; j < 3; j++)
  {
    float32x4_t col[6];
    a64_fp16_4x4_3x3_apply_g(w[0][j], w[1][j], w[2][j], col);
    for (int i = 0; i < 6; i++)
    {
      Ww[i][j] = col[i];
    }
  }

  // V = Ww G'^T (6x6): transform each row of Ww, then remove the 24^2 scale.
  const float recip576 = 1.0f / 576.0f;
  for (int i = 0; i < 6; i++)
  {
    a64_fp16_4x4_3x3_apply_g(Ww[i][0], Ww[i][1], Ww[i][2], V[i]);
    for (int j = 0; j < 6; j++)
    {
      V[i][j] = vmulq_n_f32(V[i][j], recip576);
    }
  }
}

// Winograd F(4x4, 3x3) weight transform for fp16 weights.
//
// For every channel c in [0, n_channels), reads the 3x3 kernel
//   w[i][j] = inptr[c + i*ld_weight_row + j*ld_weight_col]
// (channels contiguous; NOTE: data in HWIO order). It then writes the 36
// transformed values V[i][j] to
//   outptr[c + (6*i + j)*matrix_stride].
// The strides are in elements, not bytes.
//
// Arithmetic is performed in fp32 and each result is narrowed to fp16 once.
// The 8-channel, 4-channel and single-channel paths share the same per-lane
// computation, so a channel's result does not depend on which path handled it.
void a64_fp16_4x4_3x3(
  unsigned int n_channels,
  const __fp16* inptr,  // NOTE: Data in HWIO order
  const size_t ld_weight_row,
  const size_t ld_weight_col,
  __fp16* outptr,
  const size_t matrix_stride
)
{
  // Eight channels per iteration: one fp16x8 load, split into two fp32 halves.
  for (; n_channels >= 8; n_channels -= 8)
  {
    float32x4_t w_lo[3][3], w_hi[3][3], V_lo[6][6], V_hi[6][6];

    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        const float16x8_t x = vld1q_f16(inptr + i*ld_weight_row + j*ld_weight_col);
        w_lo[i][j] = vcvt_f32_f16(vget_low_f16(x));
        w_hi[i][j] = vcvt_high_f32_f16(x);
      }
    }

    a64_fp16_4x4_3x3_transform_f32x4(w_lo, V_lo);
    a64_fp16_4x4_3x3_transform_f32x4(w_hi, V_hi);

    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        // Narrow both halves and recombine them into a single fp16x8 store.
        vst1q_f16(outptr + m*matrix_stride,
                  vcvt_high_f16_f32(vcvt_f16_f32(V_lo[i][j]), V_hi[i][j]));
      }
    }

    inptr += 8;
    outptr += 8;
  }

  // Four remaining channels, if any. The file-level guard already restricts
  // this code to AArch64, so no extra preprocessor gate is needed here.
  for (; n_channels >= 4; n_channels -= 4)
  {
    float32x4_t w[3][3], V[6][6];

    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vcvt_f32_f16(vld1_f16(inptr + i*ld_weight_row + j*ld_weight_col));
      }
    }

    a64_fp16_4x4_3x3_transform_f32x4(w, V);

    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1_f16(outptr + m*matrix_stride, vcvt_f16_f32(V[i][j]));
      }
    }

    inptr += 4;
    outptr += 4;
  }

  // Remaining 0-3 channels, one at a time. The value is broadcast into all
  // lanes and lane 0 is read back, so the arithmetic is bit-for-bit the same
  // as in the vector paths above.
  for (; n_channels; n_channels--)
  {
    float32x4_t w[3][3], V[6][6];

    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vdupq_n_f32((float) inptr[i*ld_weight_row + j*ld_weight_col]);
      }
    }

    a64_fp16_4x4_3x3_transform_f32x4(w, V);

    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        outptr[m*matrix_stride] = (__fp16) vgetq_lane_f32(V[i][j], 0);
      }
    }

    inptr++;
    outptr++;
  }
}
237 
238 }  // namespace weight_transform
239 }  // namespace winograd
240 }  // namespace arm_conv
241 
242 #endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
243