/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "src/core/NEON/kernels/assembly/winograd.hpp"
#include <cstring>
#include <memory>
#include <string>
#include <vector>

namespace arm_conv {
namespace winograd {

enum class MethodConstraints
{
  None,
  RequiresSVE  = 0x1,
  RequiresSVE2 = 0x2,
  RequiresSME  = 0x4,
  RequiresSME2 = 0x8,
  LargerShape  = 0x10,  // Input tensor shape is larger than the output transform tile shape.
};
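
// The enumerators above are single-bit flags; the operators defined below
// allow them to be combined and tested as a bitmask, e.g.:
//
//   const auto c = MethodConstraints::RequiresSVE | MethodConstraints::LargerShape;
//   if (!(c & MethodConstraints::RequiresSME)) { /* SME is not required */ }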

constexpr inline bool operator!(const MethodConstraints &c)
{
  return c == MethodConstraints::None;
}

constexpr inline MethodConstraints operator|(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) | static_cast<unsigned int>(b));
}

constexpr inline MethodConstraints operator&(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) & static_cast<unsigned int>(b));
}

inline bool constraints_met(const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &, const WinogradConfig *)
{
  return (
    (!(c & MethodConstraints::RequiresSVE) || (ci->has_sve())) &&
    (!(c & MethodConstraints::RequiresSVE2) || (ci->has_sve2())) &&
    (!(c & MethodConstraints::RequiresSME) || (ci->has_sme())) &&
    (!(c & MethodConstraints::RequiresSME2) || (ci->has_sme2()))
    // Add further constraints here
  );
}

inline bool output_transform_constraints_met(const output_transform::ITransform *transform, const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg)
{
  return (
    constraints_met(c, ci, conv_args, cfg) &&
    (!(c & MethodConstraints::LargerShape) || (conv_args.input_shape.rows > transform->get_output_rows() && conv_args.input_shape.cols > transform->get_output_cols()))
  );
}

namespace weight_transform {

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

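// Each specialisation of implementation_list (here and in the input_transform
// and output_transform namespaces below) returns a pointer to a list of
// implementations terminated by an entry whose `transform` member is null.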
template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

} // namespace weight_transform

namespace input_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

} // namespace input_transform

namespace output_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

} // namespace output_transform

namespace {

template <typename T>
constexpr T iceildiv(T num, T den)
{
  return (num + den - 1) / den;
}

template <typename T>
constexpr T iroundup(T num, T den)
{
  return den * iceildiv(num, den);
}
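
// For example, iceildiv(7, 4) == 2 (seven elements occupy two groups of four)
// and iroundup(7, 4) == 8 (seven rounded up to the next multiple of four).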

} // anonymous namespace

template <typename TWeight, typename TWinogradIn>
inline std::vector<const weight_transform::ITransform *> get_weight_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
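  // (for an output tile of size (output_rows x output_cols) and a kernel of
  //  size (kernel_rows x kernel_cols), the inner tile is of size
  //  (output_rows + kernel_rows - 1) x (output_cols + kernel_cols - 1);
  //  a zero in the configuration means no particular size is requested)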
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const weight_transform::ITransform *> weight_transforms;
  for (auto impl = weight_transform::implementation_list<TWeight, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    // If this transform supports the requested kernel size, then add it to the
    // list of weight transforms.
    if (
      constraints_met(impl->constraints, ci, conv_args, cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_transformed_tile_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_transformed_tile_cols()) &&
      (cfg->weight_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->weight_transform_filter.c_str()))
    )
    {
      weight_transforms.push_back(impl->transform.get());
    }
  }

  return weight_transforms;
}

template <typename TIn, typename TWinogradIn>
inline std::vector<const input_transform::ITransform *> get_input_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
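  // (as in get_weight_transforms: output size + kernel size - 1, with zero
  //  meaning no particular tile size is requested)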
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const input_transform::ITransform *> input_transforms;
  for (auto impl = input_transform::implementation_list<TIn, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    if (
      constraints_met(impl->constraints, ci, conv_args, cfg) &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_input_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_input_cols()) &&
      (cfg->input_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->input_transform_filter.c_str()))
    )
    {
      input_transforms.push_back(impl->transform.get());
    }
  }

  return input_transforms;
}

template <typename TWinogradOut, typename TOut>
inline std::vector<const output_transform::ITransform *> get_output_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  std::vector<const output_transform::ITransform *> output_transforms;
  for (auto impl = output_transform::implementation_list<TWinogradOut, TOut>();
       impl->transform.get() != nullptr; impl++)
  {
    if (
      output_transform_constraints_met(impl->transform.get(), impl->constraints, ci, conv_args, cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (cfg->output_rows == 0 || cfg->output_rows == impl->transform->get_output_rows()) &&
      (cfg->output_cols == 0 || cfg->output_cols == impl->transform->get_output_cols()) &&
      (cfg->output_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->output_transform_filter.c_str()))
    )
    {
      output_transforms.push_back(impl->transform.get());
    }
  }

  return output_transforms;
}

template <typename TIn, typename TWeight, typename TOut, typename TWinogradIn, typename TWinogradOut>
bool get_implementation(
  WinogradImpl &dest,  // Destination for the selected implementation
  const CPUInfo *ci,
  const ConvolutionArgs &conv_args,
  int max_threads,
  bool fast_mode,
  const WinogradConfig *cfg,
  const arm_gemm::GemmConfig *gemm_cfg
)
{
  // Get vectors of valid weight, input and output transforms; then select the
  // combination which produces the biggest output tile.
  const auto weight_transforms = get_weight_transforms<TWeight, TWinogradIn>(ci, conv_args, cfg);
  const auto input_transforms = get_input_transforms<TIn, TWinogradIn>(ci, conv_args, cfg);
  const auto output_transforms = get_output_transforms<TWinogradOut, TOut>(ci, conv_args, cfg);

  // Now attempt to select a complete set of Winograd transformations which can
  // solve the problem. Work backwards from the output transform to find
  // matching input implementations.
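  // A combination is valid when all three transforms agree on the inner tile
  // size: the output transform's input tile must match the weight transform's
  // transformed tile, and the input transform must produce a tile of the same
  // dimensions.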
  bool success = false;
  for (auto output_transform = output_transforms.cbegin();
       !success && output_transform != output_transforms.cend();
       output_transform++)
  {
    // Look for matching weight transforms; if we find one then we look for
    // matching input transforms.
    for (auto weight_transform = weight_transforms.cbegin();
         !success && weight_transform != weight_transforms.cend();
         weight_transform++)
    {
      // If this weight transform is compatible, then look for a matching input
      // transform.
      if ((*output_transform)->get_input_rows() == (*weight_transform)->get_transformed_tile_rows() &&
          (*output_transform)->get_input_cols() == (*weight_transform)->get_transformed_tile_cols())
      {
        for (auto input_transform = input_transforms.cbegin();
             !success && input_transform != input_transforms.cend();
             input_transform++)
        {
          // If the input transform is suitable, then set the configuration and
          // indicate success.
          if ((*input_transform)->get_input_rows() == (*output_transform)->get_input_rows() &&
              (*input_transform)->get_input_cols() == (*output_transform)->get_input_cols())
          {
            dest.output_transform = *output_transform;
            dest.input_transform = *input_transform;
            dest.weight_transform = *weight_transform;
            success = true;
          }
        }
      }
    }
  }

  if (!success)
  {
    return false;
  }

  // If we're able to construct the Winograd elements, then specify the GEMM
  // arguments required to perform the multiply-accumulate step of the
  // convolution.
  const auto n_output_row_tiles = iceildiv(conv_args.output_shape.rows, dest.output_transform->get_output_rows());
  const auto n_output_col_tiles = iceildiv(conv_args.output_shape.cols, dest.output_transform->get_output_cols());
  const auto n_output_patches = n_output_row_tiles * n_output_col_tiles;

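  // Each point of the inner tile corresponds to an independent GEMM; these are
  // passed to arm_gemm as the "multis" of a single batched problem.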
  const int n_multis = dest.input_transform->get_input_rows() *
                       dest.input_transform->get_input_cols();

  dest.gemm_args.reset(new arm_gemm::GemmArgs(
    ci,
    n_output_patches,  // M
    conv_args.n_output_channels,  // N
    conv_args.n_input_channels,  // K
    1,  // K-sections
    conv_args.n_batches,  // # Batches
    n_multis,
    false,  // Indirect input
    {},  // No activation
    max_threads,
    fast_mode,
    gemm_cfg
  ));

  // Also provide hints for the Winograd memory layout
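  // (leading dimensions below are rounded up to multiples of four elements,
  //  which keeps the start of each row at a convenient offset for vectorised
  //  kernels)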
  auto &ws = dest.winograd_spec;
  ws.weight_ld_row = iroundup(conv_args.n_output_channels, 4u);
  ws.weight_ld_matrix = conv_args.n_input_channels * ws.weight_ld_row;
  ws.weight_matrix_size_bytes = n_multis * ws.weight_ld_matrix * sizeof(TWinogradIn);

  ws.input_ld_row = iroundup(conv_args.n_input_channels, 4u);
  ws.input_ld_matrix = iroundup(n_output_patches, 4u) * ws.input_ld_row;
  ws.input_ld_batch = n_multis * ws.input_ld_matrix;
  ws.input_matrix_size_bytes = conv_args.n_batches * ws.input_ld_batch * sizeof(TWinogradIn);

  ws.output_ld_row = ws.weight_ld_row;
  ws.output_ld_matrix = n_output_patches * ws.output_ld_row;
  ws.output_ld_batch = n_multis * ws.output_ld_matrix;
  ws.output_matrix_size_bytes = conv_args.n_batches * ws.output_ld_batch * sizeof(TWinogradOut);

  return true;
}
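
// A minimal usage sketch, with hypothetical surrounding variables (cpu_info,
// conv_args) and an all-float instantiation assumed purely for illustration:
//
//   WinogradImpl impl;
//   WinogradConfig cfg{};  // zeroed config: no tile-size or name restrictions
//   if (get_implementation<float, float, float, float, float>(
//         impl, cpu_info, conv_args, 1 /* max_threads */, false /* fast_mode */,
//         &cfg, nullptr))
//   {
//     // impl now holds the selected transforms, the GEMM arguments and the
//     // Winograd memory-layout hints.
//   }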

} // namespace winograd
} // namespace arm_conv