/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include "convolution_parameters.hpp"
#include "ndrange.hpp"

#include <cstddef>
#include <cstdint>

namespace arm_gemm
{
// Avoid circular dependency with arm_gemm.hpp
struct GemmConfig;

// Abstract class for the GEMM/GEMV functions.
//
// GEMM implementations may be "native" (never require any input
// permutation), "pretransposed" (require permutation up-front) or require
// working space (permute as they go along). This interface should support
// all of them.

// The real GemmCommon class is templated based on the operand and return
// type. This is an interface class which is independent of those types.
class IGemmCommon
{
public:
    /* Pass in the pointers to the arrays to be operated on and their
     * strides. This "generic" version uses void *s, the preferred version
     * is the one provided by templated GemmCommon (below) which takes
     * appropriately typed pointers. If B is pretransposed (see below) then
     * the settings for B here are ignored.
     */
    virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                                    const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
                                    void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                                    const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;

    /** @returns an ndrange containing ranges of the compute space which can be
     * broken up and parallelised over
     */
    virtual ndrange_t get_window_size() const = 0;

    /* The maximum thread count is specified when the GEMM is created. Some
     * implementations need to know how many threads will actually run in
     * order to work properly.
     *
     * In some cases, after creating the GEMM the number of threads needs to
     * be reduced (e.g. not enough work to split across threads). This
     * method allows the number of actual threads to be run to be set (must
     * be equal or lower).
     *
     * This has an empty default implementation, as GEMMs which don't care
     * about thread count can safely ignore this.
     */
    virtual void set_nthreads(int) {};
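
    // Illustrative note (not part of the interface): a caller would typically
    // clamp the thread count passed here to the amount of parallel work
    // reported by get_window_size(), along the lines of
    //     gemm->set_nthreads(std::min<int>(max_threads, work_units));
    // where 'work_units' is assumed to have been derived from get_window_size().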
    /* Whether this GEMM can be dynamically scheduled or not. */
    virtual bool supports_dynamic_scheduling() const
    {
        return false;
    }

    /** Main execute member function
     * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size()
     * @param [in] thread_locator where are we inside of the thread space
     * @param [in] threadid a unique threadid
     */
    virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0;

    /*** Working space interface (optional) ***/
    /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */
    virtual size_t get_working_size() const
    {
        return 0;
    }
    /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */
    virtual void set_working_space(void *) {};

    /*** "Pretransposed" interface (optional) ***/
    /* Is this object set up for pretranspose? If so, pretranspose_B_array() needs to be called before execute(). */
    virtual bool B_is_pretransposed() const
    {
        return false;
    }
    /* Does pretranspose still need to be done? */
    virtual bool B_pretranspose_required() const
    {
        return false;
    }
    /* Total number of bytes of space needed for pretransposed arrays. */
    virtual size_t get_B_pretransposed_array_size() const
    {
        return 0;
    }
    /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
    /* The "real" version of this depends on the templated operand type (see below). */
    virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
    /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
    virtual void set_pretransposed_B_data(void *)
    {
    }

    /*** "Quantized bias" interface (optional) ***/
    /* Set the bias vector for quantized GEMMs */
    virtual void set_quantized_bias(const int32_t *, size_t)
    {
    }

    /*** Indirect interface (optional) ***/
    /* Set the indirect table. This comprises a number of values per kernel point, and a densely packed array of pointers,
     * multis * batches * kernel_points. */
    virtual void set_indirect_parameters_generic(size_t, const void *const *const *)
    {
    }

    /*** Convolution interface (optional) ***/
    /* Set the convolution parameters. */
    virtual void set_convolution_parameters(ConvolutionParameters)
    {
    }

    /*** Introspection interface ***/
    /* Get the configuration of this GEMM */
    virtual GemmConfig get_config() = 0;

    // Destructor
    virtual ~IGemmCommon()
    {
    }
};
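
// A minimal caller-side sketch (illustrative only, not part of the library) of
// how the optional interfaces above are typically driven, assuming a single
// thread; the allocation helper and the construction of the ndcoord_t work
// range are placeholders:
//
//     IGemmCommon *gemm = /* obtained from the usual gemm factory */;
//     gemm->set_nthreads(1);
//     gemm->set_arrays_generic(A, lda, A_batch_stride, A_multi_stride,
//                              B, ldb, B_multi_stride,
//                              C, ldc, C_batch_stride, C_multi_stride,
//                              bias, bias_multi_stride);
//     if (gemm->get_working_size() > 0)
//         gemm->set_working_space(allocate(gemm->get_working_size()));
//     if (gemm->B_is_pretransposed() && gemm->B_pretranspose_required())
//     {
//         void *packed = allocate(gemm->get_B_pretransposed_array_size());
//         gemm->pretranspose_B_array_generic(packed, B, ldb, B_multi_stride);
//     }
//     gemm->execute(full_range /* covering get_window_size() */, thread_locator, 0);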

/* "Real" GemmCommon class which is templated on the operand and return types.
 *
 * In addition to correctly typed versions of the functions that operate on
 * operand and return data, this class provides a default implementation of
 * 'set_arrays' to capture the provided arguments in protected class
 * members, as essentially any implementation will need these.
 */
template <typename To, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
    const To *_Aptr = nullptr;
    int _lda = 0;
    int _A_batch_stride = 0;
    int _A_multi_stride = 0;
    const To *_Bptr = nullptr;
    int _ldb = 0;
    int _B_multi_stride = 0;
    Tr *_Cptr = nullptr;
    int _ldc = 0;
    int _C_batch_stride = 0;
    int _C_multi_stride = 0;
    const Tr *_bias = nullptr;
    int _bias_multi_stride = 0;

public:
    /* Pass in the pointers to the arrays to be operated on and their
     * strides (templated version with appropriate types). */
    virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                            const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
                            Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                            const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride)
    {
        _Aptr              = A;
        _lda               = lda;
        _A_batch_stride    = A_batch_stride;
        _A_multi_stride    = A_multi_stride;
        _Bptr              = B;
        _ldb               = ldb;
        _B_multi_stride    = B_multi_stride;
        _Cptr              = C;
        _ldc               = ldc;
        _C_batch_stride    = C_batch_stride;
        _C_multi_stride    = C_multi_stride;
        _bias              = bias;
        _bias_multi_stride = bias_multi_stride;
    }

    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
    void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                            const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
                            void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                            const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override
    {
        set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
                   static_cast<const To *>(B), ldb, B_multi_stride,
                   static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
                   static_cast<const Tr *>(bias), bias_multi_stride);
    }

    /*** "Pretransposed" interface ***/

    /* Compute col sums over all columns */
    virtual void requantize_bias(void *, const To *, const int, const int) {};

    /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
    /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
    virtual void pretranspose_B_array(void *, const To *, const int, const int) {};

    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
    void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
    {
        pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
    }

    /*** Indirect interface ***/
    virtual void set_indirect_parameters(size_t, const To *const *const *)
    {
    }

    void set_indirect_parameters_generic(size_t sz, const void *const *const *ptr) override
    {
        set_indirect_parameters(sz, reinterpret_cast<const To *const *const *>(ptr));
    }
};

} // namespace arm_gemm
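
// Illustrative sketch (not part of the library) of the shape of a concrete
// implementation: deriving from GemmCommon<To, Tr> provides the set_arrays
// capture above, so a minimal "native" kernel only has to describe its
// parallelisable window and do the work in execute(). The class name and the
// window decomposition below are hypothetical:
//
//     class MyNativeGemm : public arm_gemm::GemmCommon<float, float>
//     {
//     public:
//         arm_gemm::ndrange_t get_window_size() const override
//         {
//             return /* an ndrange_t describing the units of parallel work */;
//         }
//
//         void execute(const arm_gemm::ndcoord_t &work_range, const arm_gemm::ndcoord_t &, int) override
//         {
//             // Read the operands via the captured _Aptr/_Bptr/_Cptr members and
//             // strides, computing only the portion described by work_range.
//         }
//
//         arm_gemm::GemmConfig get_config() override
//         {
//             return /* a GemmConfig describing this implementation */;
//         }
//     };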