1 /* 2 * Copyright (c) 2022 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 25 #pragma once 26 27 #include "src/core/NEON/kernels/assembly/winograd.hpp" 28 #include <algorithm> 29 #include <functional> 30 31 namespace arm_conv { 32 namespace winograd { 33 namespace weight_transform { 34 35 /* Driver class for the Winograd weight transforms. 36 */ 37 template <typename TIn, typename TOut=TIn> 38 class Transform : public ITransform 39 { 40 using Kernel = std::function<void( 41 unsigned int n_channels, // Number of channels to transform 42 const TIn *inptr, size_t ld_in_row, size_t ld_in_col, 43 TOut *outptr, size_t ld_out_matrix 44 )>; 45 46 const std::string m_name; 47 const unsigned int m_kernel_rows, m_kernel_cols; 48 const unsigned int m_transformed_tile_rows, m_transformed_tile_cols; 49 const Kernel m_kernel; 50 execute_internal(const ConvolutionArgs & args,const TIn * inptr,size_t ld_in_row,size_t ld_in_col,size_t ld_input_channel,TOut * outptr,size_t ld_out_matrix,size_t ld_out_row,unsigned int thread_id,unsigned int n_threads) const51 void execute_internal( 52 const ConvolutionArgs &args, 53 const TIn *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel, 54 TOut *outptr, size_t ld_out_matrix, size_t ld_out_row, 55 unsigned int thread_id, unsigned int n_threads 56 ) const 57 { 58 // Stripe groups of input channels over threads, this should reduce false 59 // sharing of the output matrix. 60 constexpr auto n_input_channels_per_thread = 16u; 61 62 // Get the initial offset for the input and output pointers 63 const auto offset = thread_id * n_input_channels_per_thread; 64 inptr += offset * ld_input_channel; 65 outptr += offset * ld_out_row; 66 67 for (auto start_ic = thread_id * n_input_channels_per_thread; 68 start_ic < args.n_input_channels; 69 start_ic += n_threads * n_input_channels_per_thread) 70 { 71 // Now iterate over the input channels assigned to this thread. 72 const auto end_ic = std::min(args.n_input_channels, 73 start_ic + n_input_channels_per_thread); 74 for (auto ic = start_ic; ic < end_ic; ic++) 75 { 76 m_kernel(args.n_output_channels, inptr, ld_in_row, ld_in_col, 77 outptr, ld_out_matrix); 78 inptr += ld_input_channel; 79 outptr += ld_out_row; 80 } 81 82 // Progress the pointers to the account for the work not performed by 83 // this thread. 84 const auto skip = (n_threads - 1) * n_input_channels_per_thread; 85 inptr += skip * ld_input_channel; 86 outptr += skip * ld_out_row; 87 } 88 } 89 90 public: Transform(const std::string & name,unsigned int kernel_rows,unsigned int kernel_cols,unsigned int transformed_tile_rows,unsigned int transformed_tile_cols,const Kernel kernel)91 Transform( 92 const std::string &name, 93 unsigned int kernel_rows, unsigned int kernel_cols, 94 unsigned int transformed_tile_rows, unsigned int transformed_tile_cols, 95 const Kernel kernel 96 ) 97 : m_name(name), 98 m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols), 99 m_transformed_tile_rows(transformed_tile_rows), m_transformed_tile_cols(transformed_tile_cols), 100 m_kernel(kernel) 101 { 102 } 103 get_name(void) const104 const std::string &get_name(void) const override { return m_name; } 105 get_kernel_rows(void) const106 unsigned int get_kernel_rows(void) const override { return m_kernel_rows; } get_kernel_cols(void) const107 unsigned int get_kernel_cols(void) const override { return m_kernel_cols; } 108 get_transformed_tile_rows(void) const109 unsigned int get_transformed_tile_rows(void) const override { return m_transformed_tile_rows; } get_transformed_tile_cols(void) const110 unsigned int get_transformed_tile_cols(void) const override { return m_transformed_tile_cols; } 111 execute(const ConvolutionArgs & args,const void * inptr,size_t ld_in_row,size_t ld_in_col,size_t ld_input_channel,void * outptr,size_t ld_out_matrix,size_t ld_out_row,unsigned int thread_id,unsigned int n_threads) const112 void execute( 113 const ConvolutionArgs &args, 114 const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel, 115 void *outptr, size_t ld_out_matrix, size_t ld_out_row, 116 unsigned int thread_id, unsigned int n_threads 117 ) const override 118 { 119 execute_internal( 120 args, 121 reinterpret_cast<const TIn *>(inptr), ld_in_row, ld_in_col, ld_input_channel, 122 reinterpret_cast<TOut *>(outptr), ld_out_matrix, ld_out_row, 123 thread_id, n_threads 124 ); 125 } 126 127 /* Utility method to get a transposed variant of a kernel, this transposed 128 * version simply calls the original kernel with the input row and column 129 * strides swapped. 130 */ get_transposed_kernel(const Kernel & kernel)131 static constexpr Kernel get_transposed_kernel(const Kernel &kernel) 132 { 133 return [kernel] ( 134 const unsigned int n_channels, 135 const TIn *const inptr, const size_t ld_in_row, const size_t ld_in_col, 136 TOut *const outptr, const size_t ld_out 137 ) { 138 kernel(n_channels, inptr, ld_in_col, ld_in_row, outptr, ld_out); 139 }; 140 } 141 }; 142 143 } // namespace weight_transform 144 } // namespace winograd 145 } // namespace arm_conv 146