/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 25 #pragma once 26 27 #include "src/core/NEON/kernels/arm_gemm/utils.hpp" 28 #include "interleaves/generic.hpp" 29 #include "depthfirst_driver.hpp" 30 31 namespace arm_conv { 32 namespace depthwise { 33 34 class DepthfirstStrategyUntyped : public IDepthfirstStrategy 35 { 36 public: 37 virtual arm_gemm::VLType get_vl_type() const = 0; 38 39 virtual unsigned int get_kernel_rows() const = 0; 40 virtual unsigned int get_kernel_cols() const = 0; 41 42 virtual unsigned int get_stride_rows() const = 0; 43 virtual unsigned int get_stride_cols() const = 0; 44 45 virtual unsigned int get_input_rows() const override; 46 virtual unsigned int get_input_cols() const override; 47 48 virtual unsigned int get_n_input_points() const; 49 virtual unsigned int get_n_output_points() const; 50 virtual unsigned int get_n_kernel_points() const; 51 52 // Get the number of VLs used in the accumulator, this defaults to 1. 53 virtual unsigned int get_accumulator_depth_vl() const; 54 55 // Get the order in which to pack the weights, this defaults to a row-major 56 // sweep over the weight tensor. 
57 virtual bool get_kernel_packing_point(const unsigned int index, unsigned int &x, unsigned int &y) const; 58 }; 59 60 template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage> 61 class DepthfirstStrategy : public DepthfirstStrategyUntyped 62 { 63 public: get_storage_size(const DepthwiseArgs & args) const64 virtual size_t get_storage_size(const DepthwiseArgs &args) const 65 { 66 interleaves::PackingArguments packing_args( 67 this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight), 68 true, sizeof(TAccum), 69 this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(), 70 [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool 71 { return this->get_kernel_packing_point(idx, x, y); } 72 ); 73 return interleaves::get_storage_size_generic(packing_args, args); 74 } 75 pack_parameters(const DepthwiseArgs & args,void * buffer,const void * biases,const OutputStage &,const void * weights,size_t ld_weight_col,size_t ld_weight_row) const76 virtual void pack_parameters( 77 const DepthwiseArgs &args, void *buffer, 78 const void *biases, const OutputStage &, 79 const void *weights, size_t ld_weight_col, size_t ld_weight_row 80 ) const 81 { 82 interleaves::PackingArguments packing_args( 83 this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight), 84 true, sizeof(TAccum), 85 this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(), 86 [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool 87 { return this->get_kernel_packing_point(idx, x, y); } 88 ); 89 interleaves::pack_parameters_generic( 90 packing_args, args, buffer, biases, weights, ld_weight_col, ld_weight_row); 91 } 92 }; 93 94 } // namespace depthwise 95 } // namespace arm_conv 96