1 /* 2 * Copyright (c) 2021-2022 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 25 #pragma once 26 27 #include "arm_gemm.hpp" 28 #include "arm_gemm_local.hpp" 29 #include "depthwise_common.hpp" 30 31 namespace arm_conv 32 { 33 namespace depthwise 34 { 35 struct DepthwiseConfig 36 { 37 DepthwiseMethod method = DepthwiseMethod::DEFAULT; 38 std::string filter = ""; 39 DepthwiseConfigarm_conv::depthwise::DepthwiseConfig40 DepthwiseConfig(DepthwiseMethod method) 41 : method(method) {}; DepthwiseConfigarm_conv::depthwise::DepthwiseConfig42 DepthwiseConfig() {}; 43 }; 44 45 struct DepthwiseArgs 46 { 47 const CPUInfo *cpu_info; 48 49 unsigned int kernel_rows, kernel_cols; 50 unsigned int stride_rows, stride_cols; 51 52 unsigned int n_batches, input_rows, input_cols, input_channels; 53 unsigned int output_rows, output_cols; 54 unsigned int channel_multiplier; 55 56 PaddingValues padding; 57 58 arm_gemm::Activation activation; 59 60 const DepthwiseConfig *config; 61 62 bool fast_mode = false; 63 DepthwiseArgsarm_conv::depthwise::DepthwiseArgs64 DepthwiseArgs( 65 const CPUInfo *cpu_info, 66 unsigned int kernel_rows, unsigned int kernel_cols, 67 unsigned int stride_rows, unsigned int stride_cols, 68 unsigned int n_batches, unsigned int input_rows, unsigned int input_cols, 69 unsigned int input_channels, 70 unsigned int output_rows, unsigned int output_cols, 71 unsigned int channel_multiplier, 72 PaddingValues padding, arm_gemm::Activation activation, 73 const DepthwiseConfig *config) 74 : cpu_info(cpu_info), kernel_rows(kernel_rows), kernel_cols(kernel_cols), stride_rows(stride_rows), stride_cols(stride_cols), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols), 75 input_channels(input_channels), output_rows(output_rows), output_cols(output_cols), channel_multiplier(channel_multiplier), padding(padding), activation(activation), config(config) 76 { 77 } 78 }; 79 80 template <typename TInput, typename TWeight, typename TOutput> 81 class DepthwiseCommon : public IDepthwiseCommon 82 { 83 private: 84 std::string _name{}; 85 86 protected: 87 const DepthwiseArgs m_args; // Copy of arguments 88 89 public: name() const90 std::string name() const 91 { 92 return _name; 93 } 94 set_name(const std::string & n)95 void set_name(const std::string &n) 96 { 97 _name = n; 98 } 99 DepthwiseCommon(const DepthwiseArgs & args)100 DepthwiseCommon(const DepthwiseArgs &args) 101 : m_args(args) {}; 102 DepthwiseCommon(DepthwiseCommon &) = delete; 103 DepthwiseCommon &operator=(DepthwiseCommon &) = delete; 104 execute(const void * const input,const void * const parameters,void * const output,void * const working_space,const unsigned int thread_id,const unsigned int n_threads) const105 void execute( 106 const void *const input, 107 const void *const parameters, 108 void *const output, 109 void *const working_space, 110 const unsigned int thread_id, 111 const unsigned int n_threads) const override final 112 { 113 const size_t ld_input_col = m_args.input_channels; 114 const size_t ld_input_row = ld_input_col * m_args.input_cols; 115 const size_t ld_input_batch = ld_input_row * m_args.input_rows; 116 const size_t ld_output_col = m_args.input_channels * m_args.channel_multiplier; 117 const size_t ld_output_row = ld_output_col * m_args.output_cols; 118 const size_t ld_output_batch = ld_output_row * m_args.output_rows; 119 120 execute( 121 input, ld_input_col, ld_input_row, ld_input_batch, 122 parameters, output, ld_output_col, ld_output_row, ld_output_batch, 123 working_space, thread_id, n_threads); 124 } 125 execute(const void * const input,size_t ld_input_col,size_t ld_input_row,size_t ld_input_batch,const void * const parameters,void * const output,size_t ld_output_col,size_t ld_output_row,size_t ld_output_batch,void * const working_space,const unsigned int thread_id,const unsigned int n_threads) const126 void execute( 127 const void *const input, 128 size_t ld_input_col, 129 size_t ld_input_row, 130 size_t ld_input_batch, 131 const void *const parameters, 132 void *const output, 133 size_t ld_output_col, 134 size_t ld_output_row, 135 size_t ld_output_batch, 136 void *const working_space, 137 const unsigned int thread_id, 138 const unsigned int n_threads) const override final 139 { 140 execute( 141 m_args.n_batches, m_args.input_rows, m_args.input_cols, 142 m_args.input_channels, m_args.padding, 143 input, ld_input_col, ld_input_row, ld_input_batch, 144 parameters, 145 m_args.output_rows, m_args.output_cols, 146 output, ld_output_col, ld_output_row, ld_output_batch, 147 working_space, thread_id, n_threads); 148 } 149 execute(unsigned int batches,unsigned int input_height,unsigned int input_width,unsigned int channels,const PaddingValues & padding,const void * input,size_t ld_input_col,size_t ld_input_row,size_t ld_input_batch,const void * parameters,unsigned int output_height,unsigned int output_width,void * output,size_t ld_output_col,size_t ld_output_row,size_t ld_output_batch,void * working_space,unsigned int thread_id,unsigned int n_threads) const150 void execute( 151 unsigned int batches, 152 unsigned int input_height, 153 unsigned int input_width, 154 unsigned int channels, 155 const PaddingValues &padding, 156 const void *input, 157 size_t ld_input_col, 158 size_t ld_input_row, 159 size_t ld_input_batch, 160 const void *parameters, 161 unsigned int output_height, 162 unsigned int output_width, 163 void *output, 164 size_t ld_output_col, 165 size_t ld_output_row, 166 size_t ld_output_batch, 167 void *working_space, 168 unsigned int thread_id, 169 unsigned int n_threads) const override final 170 { 171 this->execute_internal( 172 batches, input_height, input_width, channels, padding, input, 173 ld_input_col, ld_input_row, ld_input_batch, parameters, output_height, 174 output_width, output, ld_output_col, ld_output_row, ld_output_batch, 175 working_space, thread_id, n_threads); 176 } 177 178 protected: 179 virtual void execute_internal( 180 unsigned int batches, 181 unsigned int input_height, 182 unsigned int input_width, 183 unsigned int channels, 184 const PaddingValues &, 185 const void *input, 186 size_t ld_input_col, 187 size_t ld_input_row, 188 size_t ld_input_batch, 189 const void *parameters, 190 unsigned int output_height, 191 unsigned int output_width, 192 void *output, 193 size_t ld_output_col, 194 size_t ld_output_row, 195 size_t ld_output_batch, 196 void *working_space, 197 unsigned int thread_id, 198 unsigned int n_threads) const = 0; 199 }; 200 201 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput> 202 using UniqueDepthwiseCommon = std::unique_ptr<DepthwiseCommon<TInput, TWeight, TOutput>>; 203 204 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing> 205 KernelDescription get_depthwise_method(const DepthwiseArgs &, const OutputStage & = {}); 206 207 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing> 208 UniqueDepthwiseCommon<TInput, TWeight, TOutput> depthwise(const DepthwiseArgs &, const OutputStage & = {}); 209 210 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing> 211 std::vector<KernelDescription> get_compatible_kernels(const DepthwiseArgs &, const OutputStage & = {}); 212 213 } // namespace depthwise 214 } // namespace arm_conv 215