/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <cstddef>
#include <arm_neon.h>

namespace arm_conv {
namespace winograd {
namespace input_transform {

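/*
 * Apply the 4x4 Winograd input transform to one input tile across
 * n_channels channels.
 *
 * The kernel assumes the channel dimension is innermost and densely packed,
 * so the vector loads below pick up consecutive channels. Each of the 16
 * transformed tile elements is written to its own output matrix: element m
 * goes to outptr + m*matrix_stride, with channels again contiguous.
 */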
void arm_fp32_4x4(
  const unsigned int n_channels,
  const float *input_base,
  const size_t input_row_stride,
  const size_t input_col_stride,
  float *outptr,
  const size_t matrix_stride
)
{
  constexpr int inner_tile_rows = 4, inner_tile_cols = 4;

  // Get pointers into the input tile; the loops below advance these through
  // the channel dimension as channels are consumed.
  const float *x_ptrs[inner_tile_rows][inner_tile_cols];
  for (int i = 0; i < inner_tile_rows; i++)
  {
    // Get a pointer into the row
    const float* const row_ptr = input_base + i*input_row_stride;

    for (int j = 0; j < inner_tile_cols; j++)
    {
      x_ptrs[i][j] = row_ptr + j*input_col_stride;
    }
  }

  // Matrices used/computed in this kernel; these scalar copies serve the
  // single-channel tail loop at the end of the function.
  float x[inner_tile_rows][inner_tile_cols];
  float XTx[inner_tile_rows][inner_tile_cols];
  float U[inner_tile_rows][inner_tile_cols];

  for (int i = 0; i < inner_tile_rows; i++)
  {
    for (int j = 0; j < inner_tile_cols; j++)
    {
      x[i][j] = XTx[i][j] = 0.0f;
    }
  }

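  // The transform computed below is U = B^T . x . B (the input transform of
  // Winograd F(2x2, 3x3)), where the code's "XT" plays the role of B^T:
  //
  //        [ 1  0 -1  0 ]
  //  B^T = [ 0  1  1  0 ]
  //        [ 0 -1  1  0 ]
  //        [ 0  1  0 -1 ]
  //
  // The channel loop is vectorised: four channels at a time in 128-bit NEON
  // vectors, then two in 64-bit vectors, then a scalar tail.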
  // Perform the Winograd input transformation for each channel in the input
  // tensor.
  int channels_remaining = n_channels;
  for (; channels_remaining >= 4; channels_remaining -= 4)
  {
    // Matrices used/computed in this kernel.
    float32x4_t x[inner_tile_rows][inner_tile_cols];
    float32x4_t XTx[inner_tile_rows][inner_tile_cols];
    float32x4_t U[inner_tile_rows][inner_tile_cols];

    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vdupq_n_f32(0.0f);
        XTx[i][j] = vdupq_n_f32(0.0f);
      }
    }

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1q_f32(x_ptrs[i][j]);
        x_ptrs[i][j] += 4;
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      // XTx[0][j] = x[0][j] - x[2][j];
      XTx[0][j] = vsubq_f32(x[0][j], x[2][j]);

      // XTx[1][j] = x[1][j] + x[2][j];
      XTx[1][j] = vaddq_f32(x[1][j], x[2][j]);

      // XTx[2][j] = x[2][j] - x[1][j];
      XTx[2][j] = vsubq_f32(x[2][j], x[1][j]);

      // XTx[3][j] = x[1][j] - x[3][j];
      XTx[3][j] = vsubq_f32(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      // U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]);

      // U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]);

      // U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]);

      // U[i][3] = XTx[i][1] - XTx[i][3];
      U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1q_f32(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 4;
  }
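  // Tail: transform any remaining pair of channels with 64-bit NEON vectors.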
  for (; channels_remaining >= 2; channels_remaining -= 2)
  {
    // Matrices used/computed in this kernel.
    float32x2_t x[inner_tile_rows][inner_tile_cols];
    float32x2_t XTx[inner_tile_rows][inner_tile_cols];
    float32x2_t U[inner_tile_rows][inner_tile_cols];

    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vdup_n_f32(0.0f);
        XTx[i][j] = vdup_n_f32(0.0f);
      }
    }

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1_f32(x_ptrs[i][j]);
        x_ptrs[i][j] += 2;
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      // XTx[0][j] = x[0][j] - x[2][j];
      XTx[0][j] = vsub_f32(x[0][j], x[2][j]);

      // XTx[1][j] = x[1][j] + x[2][j];
      XTx[1][j] = vadd_f32(x[1][j], x[2][j]);

      // XTx[2][j] = x[2][j] - x[1][j];
      XTx[2][j] = vsub_f32(x[2][j], x[1][j]);

      // XTx[3][j] = x[1][j] - x[3][j];
      XTx[3][j] = vsub_f32(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      // U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]);

      // U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]);

      // U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]);

      // U[i][3] = XTx[i][1] - XTx[i][3];
      U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1_f32(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 2;
  }
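  // Tail: transform the last channel, if any, with scalar arithmetic.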
  for (; channels_remaining; channels_remaining--)
  {
    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = *(x_ptrs[i][j]++);
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      XTx[0][j] = x[0][j] - x[2][j];
      XTx[1][j] = x[1][j] + x[2][j];
      XTx[2][j] = x[2][j] - x[1][j];
      XTx[3][j] = x[1][j] - x[3][j];
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][3] = XTx[i][1] - XTx[i][3];
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        *(outptr + m*matrix_stride) = U[i][j];
      }
    }
    outptr++;
  }
}

}  // namespace input_transform
}  // namespace winograd
}  // namespace arm_conv