/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "src/core/NEON/kernels/assembly/winograd.hpp"

#include <cstring>
#include <memory>
#include <string>
#include <vector>

namespace arm_conv {
namespace winograd {

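// Bit-flags describing constraints which a transform implementation places
// on the CPU features and on the convolution problem. Values are combined
// and tested with the bitwise operators defined below.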
enum class MethodConstraints
{
  None,
  RequiresSVE  = 0x1,
  RequiresSVE2 = 0x2,
  RequiresSME  = 0x4,
  RequiresSME2 = 0x8,
  LargerShape  = 0x10,  // Input tensor shape is larger than the output transform tile shape.
};

constexpr inline bool operator!(const MethodConstraints &c)
{
  return c == MethodConstraints::None;
}

constexpr inline MethodConstraints operator|(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) | static_cast<unsigned int>(b));
}

constexpr inline MethodConstraints operator&(const MethodConstraints &a, const MethodConstraints &b)
{
  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) & static_cast<unsigned int>(b));
}

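// Returns true if the CPU-feature requirements encoded in `c` are satisfied
// by the given CPUInfo. The convolution arguments and configuration are
// accepted (currently unused) so that further constraints can be added here.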
inline bool constraints_met(const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &, const WinogradConfig *)
{
  return (
    (!(c & MethodConstraints::RequiresSVE) || (ci->has_sve())) &&
    (!(c & MethodConstraints::RequiresSVE2) || (ci->has_sve2())) &&
    (!(c & MethodConstraints::RequiresSME) || (ci->has_sme())) &&
    (!(c & MethodConstraints::RequiresSME2) || (ci->has_sme2()))
    // Add further constraints here
  );
}

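// As constraints_met, but additionally enforces MethodConstraints::LargerShape:
// the input tensor must be strictly larger than the transform's output tile.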
inline bool output_transform_constraints_met(const output_transform::ITransform *transform, const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg)
{
  return (
    constraints_met(c, ci, conv_args, cfg) &&
    (!(c & MethodConstraints::LargerShape) || (conv_args.input_shape.rows > transform->get_output_rows() && conv_args.input_shape.cols > transform->get_output_cols()))
  );
}

namespace weight_transform {

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

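// Returns the list of available weight transform implementations, terminated
// by an entry whose transform pointer is null.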
template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace weight_transform

namespace input_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace input_transform

namespace output_transform
{

template <typename TIn, typename TOut=TIn>
struct TransformImplementation
{
  std::unique_ptr<const ITransform> transform;
  MethodConstraints constraints;

  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
  : transform(transform), constraints(constraints)
  {
  }
};

template <typename TIn, typename TOut=TIn>
const TransformImplementation<TIn, TOut> *implementation_list(void);

}  // namespace output_transform

namespace {

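// Integer ceiling division, e.g. iceildiv(5, 4) == 2.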
template <typename T>
constexpr T iceildiv(T num, T den)
{
  return (num + den - 1) / den;
}

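// Round `num` up to the next multiple of `den`, e.g. iroundup(5, 4) == 8.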
template <typename T>
constexpr T iroundup(T num, T den)
{
  return den * iceildiv(num, den);
}

}  // anonymous namespace

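// Return the weight transforms which can be used for the given convolution.
// For a Winograd convolution producing an m x m output tile from an r x r
// kernel, the transformed ("inner") tile has m + r - 1 rows and columns; the
// target inner tile size below is therefore derived from the output tile size
// requested in the configuration (zero meaning "no preference").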
template <typename TWeight, typename TWinogradIn>
inline std::vector<const weight_transform::ITransform *> get_weight_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const weight_transform::ITransform *> weight_transforms;
  for (auto impl = weight_transform::implementation_list<TWeight, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    // If this transform supports the requested kernel size, then add it to the
    // list of weight transforms.
    if (
      constraints_met(impl->constraints, ci, conv_args, cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_transformed_tile_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_transformed_tile_cols()) &&
      (cfg->weight_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->weight_transform_filter.c_str()))
    )
    {
      weight_transforms.push_back(impl->transform.get());
    }
  }

  return weight_transforms;
}

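// As get_weight_transforms, but for input transforms; input transforms are
// matched on their input (inner) tile size rather than on the kernel size.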
template <typename TIn, typename TWinogradIn>
inline std::vector<const input_transform::ITransform *> get_input_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  // Get target inner tile size
  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);

  std::vector<const input_transform::ITransform *> input_transforms;
  for (auto impl = input_transform::implementation_list<TIn, TWinogradIn>();
       impl->transform.get() != nullptr; impl++)
  {
    if (
      constraints_met(impl->constraints, ci, conv_args, cfg) &&
      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_input_rows()) &&
      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_input_cols()) &&
      (cfg->input_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->input_transform_filter.c_str()))
    )
    {
      input_transforms.push_back(impl->transform.get());
    }
  }

  return input_transforms;
}

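// As above, but for output transforms; these are matched on the kernel size
// and on the requested output tile size, and must additionally satisfy any
// shape constraints (see output_transform_constraints_met).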
template <typename TWinogradOut, typename TOut>
inline std::vector<const output_transform::ITransform *> get_output_transforms(
  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
)
{
  std::vector<const output_transform::ITransform *> output_transforms;
  for (auto impl = output_transform::implementation_list<TWinogradOut, TOut>();
       impl->transform.get() != nullptr; impl++)
  {
    if (
      output_transform_constraints_met(impl->transform.get(), impl->constraints, ci, conv_args, cfg) &&
      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
      (cfg->output_rows == 0 || cfg->output_rows == impl->transform->get_output_rows()) &&
      (cfg->output_cols == 0 || cfg->output_cols == impl->transform->get_output_cols()) &&
      (cfg->output_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->output_transform_filter.c_str()))
    )
    {
      output_transforms.push_back(impl->transform.get());
    }
  }

  return output_transforms;
}

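// Select a complete, mutually compatible set of Winograd transforms for the
// given convolution, then populate `dest` with the chosen transforms, the
// GEMM arguments and the memory-layout hints. Returns false if no valid
// combination can be found.
//
// A minimal usage sketch (an fp32 convolution is assumed, and `ci`,
// `conv_args` and `cfg` are assumed to have been constructed by the caller):
//
//   WinogradImpl impl;
//   const bool ok = get_implementation<float, float, float, float, float>(
//     impl, ci, conv_args, /* max_threads */ 1, /* fast_mode */ false, cfg, nullptr);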
template <typename TIn, typename TWeight, typename TOut, typename TWinogradIn, typename TWinogradOut>
bool get_implementation(
  WinogradImpl &dest,  // Destination for the selected implementation
  const CPUInfo *ci,
  const ConvolutionArgs &conv_args,
  int max_threads,
  bool fast_mode,
  const WinogradConfig *cfg,
  const arm_gemm::GemmConfig *gemm_cfg
)
{
  // Get vectors of valid weight, input and output transforms; then select the
  // combination which produces the biggest output tile.
  const auto weight_transforms = get_weight_transforms<TWeight, TWinogradIn>(ci, conv_args, cfg);
  const auto input_transforms = get_input_transforms<TIn, TWinogradIn>(ci, conv_args, cfg);
  const auto output_transforms = get_output_transforms<TWinogradOut, TOut>(ci, conv_args, cfg);

  // Now attempt to select a complete set of Winograd transformations which can
  // solve the problem. Work backwards from the output transform to find
  // matching input implementations.
  bool success = false;
  for (auto output_transform = output_transforms.cbegin();
       !success && output_transform != output_transforms.cend();
       output_transform++)
  {
    // Look for matching weight transforms; if we find one, then we look for
    // matching input transforms.
    for (auto weight_transform = weight_transforms.cbegin();
         !success && weight_transform != weight_transforms.cend();
         weight_transform++)
    {
      // If this weight transform is compatible, then look for a matching input
      // transform.
      if ((*output_transform)->get_input_rows() == (*weight_transform)->get_transformed_tile_rows() &&
          (*output_transform)->get_input_cols() == (*weight_transform)->get_transformed_tile_cols())
      {
        for (auto input_transform = input_transforms.cbegin();
             !success && input_transform != input_transforms.cend();
             input_transform++)
        {
          // If the input transform is suitable, then set the configuration and
          // indicate success.
          if ((*input_transform)->get_input_rows() == (*output_transform)->get_input_rows() &&
              (*input_transform)->get_input_cols() == (*output_transform)->get_input_cols())
          {
            dest.output_transform = *output_transform;
            dest.input_transform = *input_transform;
            dest.weight_transform = *weight_transform;
            success = true;
          }
        }
      }
    }
  }

  if (!success)
  {
    return false;
  }

  // If we're able to construct the Winograd elements, then specify the GEMM
  // arguments required to perform the multiply-accumulate step of the
  // convolution.
  const auto n_output_row_tiles = iceildiv(conv_args.output_shape.rows, dest.output_transform->get_output_rows());
  const auto n_output_col_tiles = iceildiv(conv_args.output_shape.cols, dest.output_transform->get_output_cols());
  const auto n_output_patches = n_output_row_tiles * n_output_col_tiles;

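  // Each element of the transformed (inner) tile gives rise to an independent
  // GEMM, so the inner tile area gives the number of GEMM "multis".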
  const int n_multis = dest.input_transform->get_input_rows() *
                       dest.input_transform->get_input_cols();

  dest.gemm_args.reset(new arm_gemm::GemmArgs(
    ci,
    n_output_patches,  // M
    conv_args.n_output_channels,  // N
    conv_args.n_input_channels,  // K
    1,  // K-sections
    conv_args.n_batches,  // # Batches
    n_multis,
    false,  // Indirect input
    {},  // No activation
    max_threads,
    fast_mode,
    gemm_cfg
  ));

  // Also provide hints for the Winograd memory layout
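  // Leading dimensions below are rounded up to multiples of four elements;
  // this padding is assumed to match the strides expected by the GEMM kernels.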
  auto &ws = dest.winograd_spec;
  ws.weight_ld_row = iroundup(conv_args.n_output_channels, 4u);
  ws.weight_ld_matrix = conv_args.n_input_channels * ws.weight_ld_row;
  ws.weight_matrix_size_bytes = n_multis * ws.weight_ld_matrix * sizeof(TWinogradIn);

  ws.input_ld_row = iroundup(conv_args.n_input_channels, 4u);
  ws.input_ld_matrix = iroundup(n_output_patches, 4u) * ws.input_ld_row;
  ws.input_ld_batch = n_multis * ws.input_ld_matrix;
  ws.input_matrix_size_bytes = conv_args.n_batches * ws.input_ld_batch * sizeof(TWinogradIn);

  ws.output_ld_row = ws.weight_ld_row;
  ws.output_ld_matrix = n_output_patches * ws.output_ld_row;
  ws.output_ld_batch = n_multis * ws.output_ld_matrix;
  ws.output_matrix_size_bytes = conv_args.n_batches * ws.output_ld_batch * sizeof(TWinogradOut);

  return true;
}

}  // namespace winograd
}  // namespace arm_conv