1 /*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #include <algorithm>
26 #include <cstddef>
27 #include <arm_neon.h>
28
29 namespace arm_conv {
30 namespace winograd {
31 namespace output_transform {
32
/**
 * fp32 Winograd output transform: turns a 6x6 Winograd-domain tile into a
 * 2x2 output tile by evaluating f = ZT F Z, then adds an optional bias and
 * clamps the result. (The "5x5" in the name presumably refers to the kernel
 * size of the convolution - not visible from this function alone.)
 *
 * With
 *        ZT = [ 1  1  1  1  1  0 ]
 *             [ 0  1 -1  2 -2  1 ]
 * every stored value is max(min(f[i][j] + bias, output_max), output_min).
 *
 * Channels are consumed four at a time, then two at a time, and finally a
 * single leftover channel is handled with scalar code.
 *
 * @param n_channels         Number of channels to transform.
 * @param inptr              First of the 36 Winograd-domain matrices; entry m
 *                           (row-major index into the 6x6 tile) of a channel
 *                           is read from inptr + m*matrix_stride, consecutive
 *                           channels being adjacent floats.
 * @param matrix_stride      Distance, in floats, between successive matrices.
 * @param bptr               Per-channel bias values, or nullptr for no bias.
 * @param outptr             Destination; element (i, j) of a channel is
 *                           written to
 *                           outptr + i*output_row_stride + j*output_col_stride,
 *                           consecutive channels being adjacent floats.
 * @param output_row_stride  Distance, in floats, between output rows.
 * @param output_col_stride  Distance, in floats, between output columns.
 * @param output_min         Lower clamp applied to every output value.
 * @param output_max         Upper clamp applied to every output value.
 */
void arm_fp32_2x2_5x5(
  unsigned int n_channels,
  const float* inptr,
  const size_t matrix_stride,
  const float* bptr,
  float *outptr,
  const size_t output_row_stride,
  const size_t output_col_stride,
  const float output_min,
  const float output_max
)
{
  constexpr unsigned int tile_dim = 6u;                  // Winograd-domain tile is tile_dim x tile_dim
  constexpr unsigned int out_rows = 2u, out_cols = 2u;   // Spatial size of the produced tile

  // ---- Quad-channel path (128-bit vectors) ----
  while (n_channels >= 4)
  {
    // The clamp bounds do not change inside this iteration; broadcast them once
    const float32x4_t upper = vdupq_n_f32(output_max);
    const float32x4_t lower = vdupq_n_f32(output_min);

    // Load the tile one row at a time and immediately form that row of F Z
    float32x4_t FZ[tile_dim][out_cols];
    for (unsigned int r = 0; r < tile_dim; r++)
    {
      float32x4_t row[tile_dim];
      for (unsigned int c = 0; c < tile_dim; c++)
      {
        row[c] = vld1q_f32(inptr + (r*tile_dim + c)*matrix_stride);
      }

      // Column 0 of Z: row[0] + row[1] + row[2] + row[3] + row[4]
      const float32x4_t sum01 = vaddq_f32(row[0], row[1]);
      const float32x4_t sum23 = vaddq_f32(row[2], row[3]);
      FZ[r][0] = vaddq_f32(vaddq_f32(sum01, sum23), row[4]);

      // Column 1 of Z: row[1] - row[2] + 2*(row[3] - row[4]) + row[5]
      const float32x4_t diff12 = vsubq_f32(row[1], row[2]);
      const float32x4_t diff34 = vsubq_f32(row[3], row[4]);
      FZ[r][1] = vaddq_f32(vmlaq_n_f32(diff12, diff34, 2.0f), row[5]);
    }

    // Left-multiply by ZT, column by column, to obtain the 2x2 tile
    float32x4_t tile[out_rows][out_cols];
    for (unsigned int c = 0; c < out_cols; c++)
    {
      const float32x4_t sum01 = vaddq_f32(FZ[0][c], FZ[1][c]);
      const float32x4_t sum23 = vaddq_f32(FZ[2][c], FZ[3][c]);
      tile[0][c] = vaddq_f32(vaddq_f32(sum01, sum23), FZ[4][c]);

      const float32x4_t diff12 = vsubq_f32(FZ[1][c], FZ[2][c]);
      const float32x4_t diff34 = vsubq_f32(FZ[3][c], FZ[4][c]);
      tile[1][c] = vaddq_f32(vmlaq_n_f32(diff12, diff34, 2.0f), FZ[5][c]);
    }

    // A missing bias pointer is treated as an all-zero bias
    const float32x4_t bias = (bptr != nullptr) ? vld1q_f32(bptr) : vdupq_n_f32(0.0f);

    // Add bias, clamp (upper bound first, then lower bound) and store
    for (unsigned int r = 0; r < out_rows; r++)
    {
      for (unsigned int c = 0; c < out_cols; c++)
      {
        float *const dst = outptr + r*output_row_stride + c*output_col_stride;
        const float32x4_t biased = vaddq_f32(tile[r][c], bias);
        vst1q_f32(dst, vmaxq_f32(vminq_f32(biased, upper), lower));
      }
    }

    // Step every stream on to the next group of four channels
    inptr += 4;
    outptr += 4;
    if (bptr != nullptr)
    {
      bptr += 4;
    }
    n_channels -= 4;
  }

  // ---- Dual-channel path (64-bit vectors) ----
  while (n_channels >= 2)
  {
    const float32x2_t upper = vdup_n_f32(output_max);
    const float32x2_t lower = vdup_n_f32(output_min);

    // Load the tile one row at a time and immediately form that row of F Z
    float32x2_t FZ[tile_dim][out_cols];
    for (unsigned int r = 0; r < tile_dim; r++)
    {
      float32x2_t row[tile_dim];
      for (unsigned int c = 0; c < tile_dim; c++)
      {
        row[c] = vld1_f32(inptr + (r*tile_dim + c)*matrix_stride);
      }

      // Column 0 of Z: row[0] + row[1] + row[2] + row[3] + row[4]
      const float32x2_t sum01 = vadd_f32(row[0], row[1]);
      const float32x2_t sum23 = vadd_f32(row[2], row[3]);
      FZ[r][0] = vadd_f32(vadd_f32(sum01, sum23), row[4]);

      // Column 1 of Z: row[1] - row[2] + 2*(row[3] - row[4]) + row[5]
      const float32x2_t diff12 = vsub_f32(row[1], row[2]);
      const float32x2_t diff34 = vsub_f32(row[3], row[4]);
      FZ[r][1] = vadd_f32(vmla_n_f32(diff12, diff34, 2.0f), row[5]);
    }

    // Left-multiply by ZT, column by column, to obtain the 2x2 tile
    float32x2_t tile[out_rows][out_cols];
    for (unsigned int c = 0; c < out_cols; c++)
    {
      const float32x2_t sum01 = vadd_f32(FZ[0][c], FZ[1][c]);
      const float32x2_t sum23 = vadd_f32(FZ[2][c], FZ[3][c]);
      tile[0][c] = vadd_f32(vadd_f32(sum01, sum23), FZ[4][c]);

      const float32x2_t diff12 = vsub_f32(FZ[1][c], FZ[2][c]);
      const float32x2_t diff34 = vsub_f32(FZ[3][c], FZ[4][c]);
      tile[1][c] = vadd_f32(vmla_n_f32(diff12, diff34, 2.0f), FZ[5][c]);
    }

    // A missing bias pointer is treated as an all-zero bias
    const float32x2_t bias = (bptr != nullptr) ? vld1_f32(bptr) : vdup_n_f32(0.0f);

    // Add bias, clamp (upper bound first, then lower bound) and store
    for (unsigned int r = 0; r < out_rows; r++)
    {
      for (unsigned int c = 0; c < out_cols; c++)
      {
        float *const dst = outptr + r*output_row_stride + c*output_col_stride;
        const float32x2_t biased = vadd_f32(tile[r][c], bias);
        vst1_f32(dst, vmax_f32(vmin_f32(biased, upper), lower));
      }
    }

    // Step every stream on to the next pair of channels
    inptr += 2;
    outptr += 2;
    if (bptr != nullptr)
    {
      bptr += 2;
    }
    n_channels -= 2;
  }

  // ---- Scalar path: at most one channel can remain at this point ----
  if (n_channels != 0)
  {
    // Note: the scalar sums below are evaluated strictly left to right, which
    // is a different grouping from the pairwise sums of the vector paths.
    float FZ[tile_dim][out_cols];
    for (unsigned int r = 0; r < tile_dim; r++)
    {
      float row[tile_dim];
      for (unsigned int c = 0; c < tile_dim; c++)
      {
        row[c] = inptr[(r*tile_dim + c)*matrix_stride];
      }

      FZ[r][0] = row[0] + row[1] + row[2] + row[3] + row[4];
      FZ[r][1] = row[1] - row[2] + 2.0f*row[3] - 2.0f*row[4] + row[5];
    }

    // Left-multiply by ZT, column by column, to obtain the 2x2 tile
    float tile[out_rows][out_cols];
    for (unsigned int c = 0; c < out_cols; c++)
    {
      tile[0][c] = FZ[0][c] + FZ[1][c] + FZ[2][c] + FZ[3][c] + FZ[4][c];
      tile[1][c] = FZ[1][c] - FZ[2][c] + 2.0f*FZ[3][c] - 2.0f*FZ[4][c] + FZ[5][c];
    }

    // A missing bias pointer is treated as a zero bias
    const float bias = (bptr != nullptr) ? *bptr : 0.0f;

    // Add bias, clamp (upper bound first, then lower bound) and store
    for (unsigned int r = 0; r < out_rows; r++)
    {
      for (unsigned int c = 0; c < out_cols; c++)
      {
        float *const dst = outptr + r*output_row_stride + c*output_col_stride;
        *dst = std::max(std::min(tile[r][c] + bias, output_max), output_min);
      }
    }
  }
}
209
210 } // namespace output_transform
211 } // namespace winograd
212 } // namespace arm_conv
213