1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #include <cstddef>
26 #include <arm_neon.h>
27
28 namespace arm_conv {
29 namespace winograd {
30 namespace input_transform {
31
/**
 * Apply the Winograd input transform to one 4x4 tile of fp32 data, for every
 * channel.
 *
 * For each channel the 4x4 tile x is replaced by U = XT . x . X, where X is
 * the transpose of XT and
 *
 *        |  1  0 -1  0 |
 *   XT = |  0  1  1  0 |
 *        |  0 -1  1  0 |
 *        |  0  1  0 -1 |
 *
 * Element (i, j) of the tile for channel c is read from
 *   input_base[i*input_row_stride + j*input_col_stride + c]
 * and element (i, j) of U for channel c is written to
 *   outptr[(4*i + j)*matrix_stride + c].
 * All strides are expressed in floats (they are applied to float pointers).
 *
 * Channels are consumed four at a time, then two at a time, then one at a
 * time, so any channel count is handled; nothing is read or written when
 * n_channels is zero.
 *
 * @param n_channels        Number of channels to transform.
 * @param input_base        Pointer to element (0, 0), channel 0 of the tile.
 * @param input_row_stride  Distance between vertically adjacent tile elements.
 * @param input_col_stride  Distance between horizontally adjacent tile elements.
 * @param outptr            Pointer to channel 0 of the first of the 16 output
 *                          matrices.
 * @param matrix_stride     Distance between consecutive output matrices.
 */
void arm_fp32_4x4(
  const unsigned int n_channels,
  const float *input_base,
  const size_t input_row_stride,
  const size_t input_col_stride,
  float *outptr,
  const size_t matrix_stride
)
{
  constexpr int inner_tile_rows = 4, inner_tile_cols = 4;

  // One read cursor per tile element; each cursor walks along the channels.
  const float *x_ptrs[inner_tile_rows][inner_tile_cols];
  for (int i = 0; i < inner_tile_rows; i++)
  {
    const float *const row_ptr = input_base + i*input_row_stride;

    for (int j = 0; j < inner_tile_cols; j++)
    {
      x_ptrs[i][j] = row_ptr + j*input_col_stride;
    }
  }

  // Kept unsigned, like n_channels: a signed counter would turn very large
  // channel counts negative and make the scalar tail loop run away.
  unsigned int channels_remaining = n_channels;

  // Four channels per iteration (128-bit vectors).
  for (; channels_remaining >= 4; channels_remaining -= 4)
  {
    float32x4_t x[inner_tile_rows][inner_tile_cols];
    float32x4_t XTx[inner_tile_rows][inner_tile_cols];
    float32x4_t U[inner_tile_rows][inner_tile_cols];

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1q_f32(x_ptrs[i][j]);
        x_ptrs[i][j] += 4;
      }
    }

    // Compute XT . x (combine rows, column by column)
    for (int j = 0; j < inner_tile_cols; j++)
    {
      XTx[0][j] = vsubq_f32(x[0][j], x[2][j]);
      XTx[1][j] = vaddq_f32(x[1][j], x[2][j]);
      XTx[2][j] = vsubq_f32(x[2][j], x[1][j]);
      XTx[3][j] = vsubq_f32(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X (combine columns, row by row)
    for (int i = 0; i < inner_tile_rows; i++)
    {
      U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]);
      U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]);
      U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]);
      U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix; m enumerates the 16 output matrices in
    // row-major tile order.
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1q_f32(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 4;
  }

  // Two channels per iteration (64-bit vectors).
  for (; channels_remaining >= 2; channels_remaining -= 2)
  {
    float32x2_t x[inner_tile_rows][inner_tile_cols];
    float32x2_t XTx[inner_tile_rows][inner_tile_cols];
    float32x2_t U[inner_tile_rows][inner_tile_cols];

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1_f32(x_ptrs[i][j]);
        x_ptrs[i][j] += 2;
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      XTx[0][j] = vsub_f32(x[0][j], x[2][j]);
      XTx[1][j] = vadd_f32(x[1][j], x[2][j]);
      XTx[2][j] = vsub_f32(x[2][j], x[1][j]);
      XTx[3][j] = vsub_f32(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]);
      U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]);
      U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]);
      U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1_f32(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 2;
  }

  // Scalar tail: at most one channel is left after the two-wide loop.
  for (; channels_remaining; channels_remaining--)
  {
    float x[inner_tile_rows][inner_tile_cols];
    float XTx[inner_tile_rows][inner_tile_cols];
    float U[inner_tile_rows][inner_tile_cols];

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = *(x_ptrs[i][j]++);
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      XTx[0][j] = x[0][j] - x[2][j];
      XTx[1][j] = x[1][j] + x[2][j];
      XTx[2][j] = x[2][j] - x[1][j];
      XTx[3][j] = x[1][j] - x[3][j];
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][3] = XTx[i][1] - XTx[i][3];
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        *(outptr + m*matrix_stride) = U[i][j];
      }
    }
    outptr++;
  }
}
248
249 } // namespace input_transform
250 } // namespace winograd
251 } // namespace arm_conv
252