xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/assembly/depthwise.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2021-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #pragma once
26 
27 #include "arm_gemm.hpp"
28 #include "arm_gemm_local.hpp"
29 #include "depthwise_common.hpp"
30 
31 namespace arm_conv
32 {
33 namespace depthwise
34 {
35 struct DepthwiseConfig
36 {
37     DepthwiseMethod method = DepthwiseMethod::DEFAULT;
38     std::string     filter = "";
39 
DepthwiseConfigarm_conv::depthwise::DepthwiseConfig40     DepthwiseConfig(DepthwiseMethod method)
41         : method(method) {};
DepthwiseConfigarm_conv::depthwise::DepthwiseConfig42     DepthwiseConfig() {};
43 };
44 
45 struct DepthwiseArgs
46 {
47     const CPUInfo *cpu_info;
48 
49     unsigned int kernel_rows, kernel_cols;
50     unsigned int stride_rows, stride_cols;
51 
52     unsigned int n_batches, input_rows, input_cols, input_channels;
53     unsigned int output_rows, output_cols;
54     unsigned int channel_multiplier;
55 
56     PaddingValues padding;
57 
58     arm_gemm::Activation activation;
59 
60     const DepthwiseConfig *config;
61 
62     bool fast_mode = false;
63 
DepthwiseArgsarm_conv::depthwise::DepthwiseArgs64     DepthwiseArgs(
65         const CPUInfo *cpu_info,
66         unsigned int kernel_rows, unsigned int kernel_cols,
67         unsigned int stride_rows, unsigned int stride_cols,
68         unsigned int n_batches, unsigned int input_rows, unsigned int input_cols,
69         unsigned int input_channels,
70         unsigned int output_rows, unsigned int output_cols,
71         unsigned int  channel_multiplier,
72         PaddingValues padding, arm_gemm::Activation activation,
73         const DepthwiseConfig *config)
74         : cpu_info(cpu_info), kernel_rows(kernel_rows), kernel_cols(kernel_cols), stride_rows(stride_rows), stride_cols(stride_cols), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols),
75           input_channels(input_channels), output_rows(output_rows), output_cols(output_cols), channel_multiplier(channel_multiplier), padding(padding), activation(activation), config(config)
76     {
77     }
78 };
79 
80 template <typename TInput, typename TWeight, typename TOutput>
81 class DepthwiseCommon : public IDepthwiseCommon
82 {
83 private:
84     std::string _name{};
85 
86 protected:
87     const DepthwiseArgs m_args; // Copy of arguments
88 
89 public:
name() const90     std::string name() const
91     {
92         return _name;
93     }
94 
set_name(const std::string & n)95     void set_name(const std::string &n)
96     {
97         _name = n;
98     }
99 
DepthwiseCommon(const DepthwiseArgs & args)100     DepthwiseCommon(const DepthwiseArgs &args)
101         : m_args(args) {};
102     DepthwiseCommon(DepthwiseCommon &) = delete;
103     DepthwiseCommon &operator=(DepthwiseCommon &) = delete;
104 
execute(const void * const input,const void * const parameters,void * const output,void * const working_space,const unsigned int thread_id,const unsigned int n_threads) const105     void execute(
106         const void *const  input,
107         const void *const  parameters,
108         void *const        output,
109         void *const        working_space,
110         const unsigned int thread_id,
111         const unsigned int n_threads) const override final
112     {
113         const size_t ld_input_col    = m_args.input_channels;
114         const size_t ld_input_row    = ld_input_col * m_args.input_cols;
115         const size_t ld_input_batch  = ld_input_row * m_args.input_rows;
116         const size_t ld_output_col   = m_args.input_channels * m_args.channel_multiplier;
117         const size_t ld_output_row   = ld_output_col * m_args.output_cols;
118         const size_t ld_output_batch = ld_output_row * m_args.output_rows;
119 
120         execute(
121             input, ld_input_col, ld_input_row, ld_input_batch,
122             parameters, output, ld_output_col, ld_output_row, ld_output_batch,
123             working_space, thread_id, n_threads);
124     }
125 
execute(const void * const input,size_t ld_input_col,size_t ld_input_row,size_t ld_input_batch,const void * const parameters,void * const output,size_t ld_output_col,size_t ld_output_row,size_t ld_output_batch,void * const working_space,const unsigned int thread_id,const unsigned int n_threads) const126     void execute(
127         const void *const  input,
128         size_t             ld_input_col,
129         size_t             ld_input_row,
130         size_t             ld_input_batch,
131         const void *const  parameters,
132         void *const        output,
133         size_t             ld_output_col,
134         size_t             ld_output_row,
135         size_t             ld_output_batch,
136         void *const        working_space,
137         const unsigned int thread_id,
138         const unsigned int n_threads) const override final
139     {
140         execute(
141             m_args.n_batches, m_args.input_rows, m_args.input_cols,
142             m_args.input_channels, m_args.padding,
143             input, ld_input_col, ld_input_row, ld_input_batch,
144             parameters,
145             m_args.output_rows, m_args.output_cols,
146             output, ld_output_col, ld_output_row, ld_output_batch,
147             working_space, thread_id, n_threads);
148     }
149 
execute(unsigned int batches,unsigned int input_height,unsigned int input_width,unsigned int channels,const PaddingValues & padding,const void * input,size_t ld_input_col,size_t ld_input_row,size_t ld_input_batch,const void * parameters,unsigned int output_height,unsigned int output_width,void * output,size_t ld_output_col,size_t ld_output_row,size_t ld_output_batch,void * working_space,unsigned int thread_id,unsigned int n_threads) const150     void execute(
151         unsigned int         batches,
152         unsigned int         input_height,
153         unsigned int         input_width,
154         unsigned int         channels,
155         const PaddingValues &padding,
156         const void          *input,
157         size_t               ld_input_col,
158         size_t               ld_input_row,
159         size_t               ld_input_batch,
160         const void          *parameters,
161         unsigned int         output_height,
162         unsigned int         output_width,
163         void                *output,
164         size_t               ld_output_col,
165         size_t               ld_output_row,
166         size_t               ld_output_batch,
167         void                *working_space,
168         unsigned int         thread_id,
169         unsigned int         n_threads) const override final
170     {
171         this->execute_internal(
172             batches, input_height, input_width, channels, padding, input,
173             ld_input_col, ld_input_row, ld_input_batch, parameters, output_height,
174             output_width, output, ld_output_col, ld_output_row, ld_output_batch,
175             working_space, thread_id, n_threads);
176     }
177 
178 protected:
179     virtual void execute_internal(
180         unsigned int batches,
181         unsigned int input_height,
182         unsigned int input_width,
183         unsigned int channels,
184         const PaddingValues &,
185         const void *input,
186         size_t       ld_input_col,
187         size_t       ld_input_row,
188         size_t       ld_input_batch,
189         const void *parameters,
190         unsigned int output_height,
191         unsigned int output_width,
192         void        *output,
193         size_t       ld_output_col,
194         size_t       ld_output_row,
195         size_t       ld_output_batch,
196         void        *working_space,
197         unsigned int thread_id,
198         unsigned int n_threads) const = 0;
199 };
200 
201 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput>
202 using UniqueDepthwiseCommon = std::unique_ptr<DepthwiseCommon<TInput, TWeight, TOutput>>;
203 
204 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
205 KernelDescription get_depthwise_method(const DepthwiseArgs &, const OutputStage & = {});
206 
207 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
208 UniqueDepthwiseCommon<TInput, TWeight, TOutput> depthwise(const DepthwiseArgs &, const OutputStage & = {});
209 
210 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
211 std::vector<KernelDescription> get_compatible_kernels(const DepthwiseArgs &, const OutputStage & = {});
212 
213 } // namespace depthwise
214 } // namespace arm_conv
215