1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #include <algorithm>
26 #include <cstddef>
27 #include <arm_neon.h>
28
namespace arm_conv {
namespace winograd {
namespace output_transform {

namespace
{

// Lane policies: each one exposes the same tiny set of float operations for
// one channel-group width. There are three of them:
//   * a 128-bit float32x4_t group (4 channels),
//   * a 64-bit float32x2_t group (2 channels),
//   * a plain scalar fallback (1 channel).
// The transform below is written once against this vocabulary, so all three
// widths perform the same operations in the same order.

struct QuadLanes
{
  using Vec = float32x4_t;
  static constexpr unsigned int width = 4;

  static Vec load(const float *src) { return vld1q_f32(src); }
  static void store(float *dst, Vec v) { vst1q_f32(dst, v); }
  static Vec splat(float x) { return vdupq_n_f32(x); }
  static Vec add(Vec a, Vec b) { return vaddq_f32(a, b); }
  static Vec sub(Vec a, Vec b) { return vsubq_f32(a, b); }
  static Vec min_of(Vec a, Vec b) { return vminq_f32(a, b); }
  static Vec max_of(Vec a, Vec b) { return vmaxq_f32(a, b); }
};

struct DuoLanes
{
  using Vec = float32x2_t;
  static constexpr unsigned int width = 2;

  static Vec load(const float *src) { return vld1_f32(src); }
  static void store(float *dst, Vec v) { vst1_f32(dst, v); }
  static Vec splat(float x) { return vdup_n_f32(x); }
  static Vec add(Vec a, Vec b) { return vadd_f32(a, b); }
  static Vec sub(Vec a, Vec b) { return vsub_f32(a, b); }
  static Vec min_of(Vec a, Vec b) { return vmin_f32(a, b); }
  static Vec max_of(Vec a, Vec b) { return vmax_f32(a, b); }
};

struct SingleLane
{
  using Vec = float;
  static constexpr unsigned int width = 1;

  static Vec load(const float *src) { return *src; }
  static void store(float *dst, Vec v) { *dst = v; }
  static Vec splat(float x) { return x; }
  static Vec add(Vec a, Vec b) { return a + b; }
  static Vec sub(Vec a, Vec b) { return a - b; }
  static Vec min_of(Vec a, Vec b) { return std::min(a, b); }
  static Vec max_of(Vec a, Vec b) { return std::max(a, b); }
};

// Output-transforms one group of Lanes::width adjacent channels and then
// advances inptr and outptr past the group. bptr is advanced only when it is
// non-null; a null bias is treated as zero.
//
// Memory-operation order for each group:
//   1. all 16 input loads,
//   2. the optional bias load,
//   3. the four output stores, in row-major order.
template <typename Lanes>
void transform_group(
  const float *&inptr,
  const size_t matrix_stride,
  const float *&bptr,
  float *&outptr,
  const size_t output_row_stride,
  const size_t output_col_stride,
  const float output_min,
  const float output_max
)
{
  using Vec = typename Lanes::Vec;

  // Gather the 4x4 Winograd-domain tile. Element (r, c) is read from matrix
  // number r*4 + c, and consecutive matrices are matrix_stride floats apart.
  Vec tile[4][4];
  for (unsigned int k = 0; k < 16; k++)
  {
    tile[k / 4][k % 4] = Lanes::load(inptr + k * matrix_stride);
  }
  inptr += Lanes::width;

  // F Z: collapse each row of four values into two columns.
  // Association is kept as (a + b) + c and (a - b) - c.
  Vec half[4][2];
  for (unsigned int r = 0; r < 4; r++)
  {
    half[r][0] = Lanes::add(Lanes::add(tile[r][0], tile[r][1]), tile[r][2]);
    half[r][1] = Lanes::sub(Lanes::sub(tile[r][1], tile[r][2]), tile[r][3]);
  }

  // ZT (F Z): collapse the four rows into the 2x2 output tile.
  Vec result[2][2];
  for (unsigned int c = 0; c < 2; c++)
  {
    result[0][c] = Lanes::add(Lanes::add(half[0][c], half[1][c]), half[2][c]);
    result[1][c] = Lanes::sub(Lanes::sub(half[1][c], half[2][c]), half[3][c]);
  }

  Vec bias = Lanes::splat(0.0f);
  if (bptr != nullptr)
  {
    bias = Lanes::load(bptr);
    bptr += Lanes::width;
  }

  const Vec upper = Lanes::splat(output_max);
  const Vec lower = Lanes::splat(output_min);
  for (unsigned int r = 0; r < 2; r++)
  {
    for (unsigned int c = 0; c < 2; c++)
    {
      // Add bias, take min against output_max, then max against output_min.
      const Vec biased = Lanes::add(result[r][c], bias);
      const Vec clamped = Lanes::max_of(Lanes::min_of(biased, upper), lower);
      Lanes::store(outptr + r * output_row_stride + c * output_col_stride, clamped);
    }
  }
  outptr += Lanes::width;
}

} // namespace

/// Winograd F(2x2, 3x3) output transform with optional bias and clamping.
///
/// For each of n_channels channels, this function:
///   * reads 16 Winograd-domain values (matrix m at inptr + m*matrix_stride),
///   * applies f = ZT F Z,
///   * adds the per-channel bias, or zero when bptr is null,
///   * clamps the result to [output_min, output_max],
///   * writes the 2x2 output tile at
///     outptr + row*output_row_stride + col*output_col_stride.
///
/// Channels are consumed in groups of 4, then 2, then 1. Within each matrix,
/// the bias vector and the output, the channels of a group are contiguous.
void arm_fp32_2x2_3x3(
  unsigned int n_channels,
  const float* inptr,
  const size_t matrix_stride,
  const float* bptr,
  float *outptr,
  const size_t output_row_stride,
  const size_t output_col_stride,
  const float output_min,
  const float output_max
)
{
  for (; n_channels >= QuadLanes::width; n_channels -= QuadLanes::width)
  {
    transform_group<QuadLanes>(inptr, matrix_stride, bptr, outptr,
                               output_row_stride, output_col_stride,
                               output_min, output_max);
  }

  for (; n_channels >= DuoLanes::width; n_channels -= DuoLanes::width)
  {
    transform_group<DuoLanes>(inptr, matrix_stride, bptr, outptr,
                              output_row_stride, output_col_stride,
                              output_min, output_max);
  }

  for (; n_channels > 0; n_channels -= SingleLane::width)
  {
    transform_group<SingleLane>(inptr, matrix_stride, bptr, outptr,
                                output_row_stride, output_col_stride,
                                output_min, output_max);
  }
}

} // namespace output_transform
} // namespace winograd
} // namespace arm_conv
221