1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #include <algorithm>
26 #include <cstddef>
27 #include <arm_neon.h>
28
29 namespace arm_conv {
30 namespace winograd {
31 namespace output_transform {
32
/**
 * Winograd output transform: turns a 6x6 tile held in the Winograd domain into
 * a 4x4 output tile (f = ZT F Z), adds an optional bias and clamps the result.
 *
 * Channels are consumed four at a time (128-bit NEON), then two at a time
 * (64-bit NEON), then one at a time (scalar) until none remain.
 *
 * @param n_channels         Number of channels to transform.
 * @param inptr              First of 36 input matrices; matrix `m` (row-major
 *                           index into the 6x6 tile) is read from
 *                           `inptr + m*matrix_stride`, channels are contiguous.
 * @param matrix_stride      Distance, in floats, between consecutive matrices.
 * @param bptr               Per-channel bias, or nullptr for a bias of zero.
 * @param outptr             Top-left element of the 4x4 output tile; channels
 *                           are contiguous.
 * @param output_row_stride  Distance, in floats, between output rows.
 * @param output_col_stride  Distance, in floats, between output columns.
 * @param output_min         Lower clamp bound (applied after the upper bound).
 * @param output_max         Upper clamp bound (applied first).
 */
void arm_fp32_4x4_3x3(
  unsigned int n_channels,
  const float* inptr,
  const size_t matrix_stride,
  const float* bptr,
  float *outptr,
  const size_t output_row_stride,
  const size_t output_col_stride,
  const float output_min,
  const float output_max
)
{
  constexpr unsigned int inner_tile_dim = 6u;
  constexpr unsigned int n_matrices = inner_tile_dim * inner_tile_dim;
  constexpr unsigned int out_rows = 4u, out_cols = 4u;

  // Clamp bounds broadcast once for each vector width.
  const float32x4_t upper_q = vdupq_n_f32(output_max);
  const float32x4_t lower_q = vdupq_n_f32(output_min);
  const float32x2_t upper_d = vdup_n_f32(output_max);
  const float32x2_t lower_d = vdup_n_f32(output_min);

  // ---- Four channels per iteration -----------------------------------------
  while (n_channels >= 4)
  {
    float32x4_t tile[6][6], tile_z[6][4], result[4][4];

    // Gather the 6x6 Winograd-domain tile; matrices are visited in row-major order.
    for (unsigned int m = 0; m < n_matrices; m++)
    {
      tile[m / inner_tile_dim][m % inner_tile_dim] = vld1q_f32(inptr + m*matrix_stride);
    }
    inptr += 4;

    // tile_z = F Z  (combine along the columns of every row)
    for (unsigned int r = 0; r < inner_tile_dim; r++)
    {
      const float32x4_t *const row = tile[r];
      const float32x4_t sum12 = vaddq_f32(row[1], row[2]);
      const float32x4_t dif12 = vsubq_f32(row[1], row[2]);
      const float32x4_t sum34 = vaddq_f32(row[3], row[4]);
      const float32x4_t dif34 = vsubq_f32(row[3], row[4]);

      // [ 1,  1,  1,  1,  1,  0 ]
      tile_z[r][0] = vaddq_f32(vaddq_f32(vaddq_f32(row[0], row[1]), vaddq_f32(row[2], row[3])), row[4]);
      // [ 0,  1, -1,  2, -2,  0 ]
      tile_z[r][1] = vmlaq_n_f32(dif12, dif34, 2.0f);
      // [ 0,  1,  1,  4,  4,  0 ]
      tile_z[r][2] = vmlaq_n_f32(sum12, sum34, 4.0f);
      // [ 0,  1, -1,  8, -8,  1 ]
      tile_z[r][3] = vaddq_f32(vmlaq_n_f32(dif12, dif34, 8.0f), row[5]);
    }

    // result = ZT (F Z)  (same weights, applied down every column)
    for (unsigned int c = 0; c < out_cols; c++)
    {
      const float32x4_t sum12 = vaddq_f32(tile_z[1][c], tile_z[2][c]);
      const float32x4_t dif12 = vsubq_f32(tile_z[1][c], tile_z[2][c]);
      const float32x4_t sum34 = vaddq_f32(tile_z[3][c], tile_z[4][c]);
      const float32x4_t dif34 = vsubq_f32(tile_z[3][c], tile_z[4][c]);

      result[0][c] = vaddq_f32(vaddq_f32(vaddq_f32(tile_z[0][c], tile_z[1][c]), vaddq_f32(tile_z[2][c], tile_z[3][c])), tile_z[4][c]);
      result[1][c] = vmlaq_n_f32(dif12, dif34, 2.0f);
      result[2][c] = vmlaq_n_f32(sum12, sum34, 4.0f);
      result[3][c] = vaddq_f32(vmlaq_n_f32(dif12, dif34, 8.0f), tile_z[5][c]);
    }

    // Bias defaults to zero when no bias pointer was supplied.
    float32x4_t bias = vdupq_n_f32(0.0f);
    if (bptr != nullptr)
    {
      bias = vld1q_f32(bptr);
      bptr += 4;
    }

    // Add bias, clamp (upper bound first, then lower bound) and store.
    for (unsigned int r = 0; r < out_rows; r++)
    {
      float *const dst_row = outptr + r*output_row_stride;
      for (unsigned int c = 0; c < out_cols; c++)
      {
        const float32x4_t biased = vaddq_f32(result[r][c], bias);
        const float32x4_t clamped = vmaxq_f32(vminq_f32(biased, upper_q), lower_q);
        vst1q_f32(dst_row + c*output_col_stride, clamped);
      }
    }
    outptr += 4;
    n_channels -= 4;
  }

  // ---- Two channels per iteration ------------------------------------------
  while (n_channels >= 2)
  {
    float32x2_t tile[6][6], tile_z[6][4], result[4][4];

    // Gather the 6x6 Winograd-domain tile.
    for (unsigned int m = 0; m < n_matrices; m++)
    {
      tile[m / inner_tile_dim][m % inner_tile_dim] = vld1_f32(inptr + m*matrix_stride);
    }
    inptr += 2;

    // tile_z = F Z
    for (unsigned int r = 0; r < inner_tile_dim; r++)
    {
      const float32x2_t *const row = tile[r];
      const float32x2_t sum12 = vadd_f32(row[1], row[2]);
      const float32x2_t dif12 = vsub_f32(row[1], row[2]);
      const float32x2_t sum34 = vadd_f32(row[3], row[4]);
      const float32x2_t dif34 = vsub_f32(row[3], row[4]);

      tile_z[r][0] = vadd_f32(vadd_f32(vadd_f32(row[0], row[1]), vadd_f32(row[2], row[3])), row[4]);
      tile_z[r][1] = vmla_n_f32(dif12, dif34, 2.0f);
      tile_z[r][2] = vmla_n_f32(sum12, sum34, 4.0f);
      tile_z[r][3] = vadd_f32(vmla_n_f32(dif12, dif34, 8.0f), row[5]);
    }

    // result = ZT (F Z)
    for (unsigned int c = 0; c < out_cols; c++)
    {
      const float32x2_t sum12 = vadd_f32(tile_z[1][c], tile_z[2][c]);
      const float32x2_t dif12 = vsub_f32(tile_z[1][c], tile_z[2][c]);
      const float32x2_t sum34 = vadd_f32(tile_z[3][c], tile_z[4][c]);
      const float32x2_t dif34 = vsub_f32(tile_z[3][c], tile_z[4][c]);

      result[0][c] = vadd_f32(vadd_f32(vadd_f32(tile_z[0][c], tile_z[1][c]), vadd_f32(tile_z[2][c], tile_z[3][c])), tile_z[4][c]);
      result[1][c] = vmla_n_f32(dif12, dif34, 2.0f);
      result[2][c] = vmla_n_f32(sum12, sum34, 4.0f);
      result[3][c] = vadd_f32(vmla_n_f32(dif12, dif34, 8.0f), tile_z[5][c]);
    }

    // Bias defaults to zero when no bias pointer was supplied.
    float32x2_t bias = vdup_n_f32(0.0f);
    if (bptr != nullptr)
    {
      bias = vld1_f32(bptr);
      bptr += 2;
    }

    // Add bias, clamp and store.
    for (unsigned int r = 0; r < out_rows; r++)
    {
      float *const dst_row = outptr + r*output_row_stride;
      for (unsigned int c = 0; c < out_cols; c++)
      {
        const float32x2_t biased = vadd_f32(result[r][c], bias);
        const float32x2_t clamped = vmax_f32(vmin_f32(biased, upper_d), lower_d);
        vst1_f32(dst_row + c*output_col_stride, clamped);
      }
    }
    outptr += 2;
    n_channels -= 2;
  }

  // ---- Remaining channels, one at a time -----------------------------------
  while (n_channels != 0)
  {
    float tile[6][6], tile_z[6][4], result[4][4];

    // Gather the 6x6 Winograd-domain tile.
    for (unsigned int m = 0; m < n_matrices; m++)
    {
      tile[m / inner_tile_dim][m % inner_tile_dim] = inptr[m*matrix_stride];
    }
    inptr++;

    // tile_z = F Z  (terms are accumulated strictly left to right)
    for (unsigned int r = 0; r < inner_tile_dim; r++)
    {
      const float *const row = tile[r];
      tile_z[r][0] = row[0] + row[1] + row[2] + row[3] + row[4];
      tile_z[r][1] = row[1] - row[2] + 2*row[3] - 2*row[4];
      tile_z[r][2] = row[1] + row[2] + 4*row[3] + 4*row[4];
      tile_z[r][3] = row[1] - row[2] + 8*row[3] - 8*row[4] + row[5];
    }

    // result = ZT (F Z)
    for (unsigned int c = 0; c < out_cols; c++)
    {
      result[0][c] = tile_z[0][c] + tile_z[1][c] + tile_z[2][c] + tile_z[3][c] + tile_z[4][c];
      result[1][c] = tile_z[1][c] - tile_z[2][c] + 2*tile_z[3][c] - 2*tile_z[4][c];
      result[2][c] = tile_z[1][c] + tile_z[2][c] + 4*tile_z[3][c] + 4*tile_z[4][c];
      result[3][c] = tile_z[1][c] - tile_z[2][c] + 8*tile_z[3][c] - 8*tile_z[4][c] + tile_z[5][c];
    }

    // Bias defaults to zero when no bias pointer was supplied.
    float bias = 0.0f;
    if (bptr != nullptr)
    {
      bias = *bptr;
      bptr++;
    }

    // Add bias, clamp and store.
    for (unsigned int r = 0; r < out_rows; r++)
    {
      float *const dst_row = outptr + r*output_row_stride;
      for (unsigned int c = 0; c < out_cols; c++)
      {
        dst_row[c*output_col_stride] =
          std::max(std::min(result[r][c] + bias, output_max), output_min);
      }
    }
    outptr++;
    n_channels--;
  }
}
239
240 } // namespace output_transform
241 } // namespace winograd
242 } // namespace arm_conv
243