xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/utils.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #pragma once
26 
#include "src/cpu/kernels/assembly/arm_gemm.hpp"

#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <tuple>
32 
// UNREACHABLE(reason): marks a point that control flow can never reach (e.g. an impossible default case in a
// switch).  This form expands to __builtin_unreachable(), so 'reason' serves only as documentation at the call site
// and is discarded; actually reaching this point at runtime is undefined behaviour.
#define UNREACHABLE(reason)  __builtin_unreachable()

// Paranoid alternative that checks at runtime instead of trusting the claim (needs <cassert>):
// #define UNREACHABLE(reason)   assert(0 && reason)
38 
39 namespace arm_gemm {
40 
// get_type_name(): Returns a printable name for type T, derived from the compiler's __PRETTY_FUNCTION__ string.
//
// The name is the text that follows the first occurrence of "cls_" in that string, up to (not including) the next
// ';' or ']' - so a type named cls_foo yields "foo".  Returns "(unknown)" if "cls_" is absent or no terminator
// follows it, and "(unsupported)" on compilers that do not define __GNUC__.
template<typename T>
std::string get_type_name() {
#ifdef __GNUC__
    // GCC renders this as "... [with T = cls_foo; ...]" and Clang as "... [T = cls_foo]", hence the ';' / ']' stops.
    const std::string signature = __PRETTY_FUNCTION__;
    const std::string prefix = "cls_";

    const std::string::size_type prefix_pos = signature.find(prefix);
    if (prefix_pos == std::string::npos) {
        return "(unknown)";
    }

    const std::string::size_type name_begin = prefix_pos + prefix.size();
    const std::string::size_type name_end = signature.find_first_of(";]", name_begin);
    if (name_end == std::string::npos) {
        return "(unknown)";
    }

    return signature.substr(name_begin, name_end - name_begin);
#else
    return "(unsupported)";
#endif
}
63 
// iceildiv(): Integer ceiling division, i.e. ceil(a / b), for b > 0.
//
// Computed as the truncated quotient plus a correction for a discarded positive remainder, rather than as
// (a + b - 1) / b.  The old form overflowed when 'a' was close to the maximum value of T (e.g. iceildiv(UINT_MAX, 2u)
// gave 0 instead of 2147483648u) and rounded incorrectly for negative 'a' with signed T (e.g. -4 / 2 gave -1).
// Results for non-negative inputs that did not overflow are unchanged.
template<typename T>
inline T iceildiv(const T a, const T b) {
    // Integer division truncates towards zero, so a round-up is only needed when a positive remainder was dropped.
    return static_cast<T>(a / b + ((a % b > 0) ? 1 : 0));
}
68 
// roundup(): Rounds 'a' up to the next multiple of 'b'; a value that is already a multiple is returned unchanged.
// Intended for non-negative 'a' and positive 'b'.
template <typename T>
inline T roundup(const T a, const T b) {
    const T remainder = a % b;

    // Already aligned - nothing to add.
    if (!remainder) {
        return a;
    }

    return a + b - remainder;
}
79 
// Vector-length regime: None (no scalable vectors), SVE, or SME.
enum class VLType { None, SVE, SME };
85 
// Output descriptor that is either "direct" (a base pointer plus a stride) or "indirect" (an array of pointers plus
// an offset).  'is_indirect' records which of the two members each constructor filled in; the other member stays
// zero-initialized.
template<typename T>
struct IndirectOutputArg {
    // Used when is_indirect == false.
    struct {
        T       *base;
        size_t   stride;
    } direct = {};
    // Used when is_indirect == true.
    struct {
        T * const *ptr;
        size_t     offset;
    } indirect = {};
    bool is_indirect = false;

    // Direct
    IndirectOutputArg(T *base, size_t stride) : is_indirect(false) {
        direct = { base, stride };
    }

    // Indirect
    IndirectOutputArg(T * const * ptr, size_t offset) : is_indirect(true) {
        indirect = { ptr, offset };
    }

    // Default: a direct descriptor with a null base and zero stride.
    IndirectOutputArg() : is_indirect(false) {
        direct = { nullptr, 0 };
    }
};
115 
116 // Check that the provided Requantize32 doesn't have a left shift.
quant_no_left_shift(const Requantize32 & qp)117 inline bool quant_no_left_shift(const Requantize32 &qp) {
118     if (qp.per_channel_requant) {
119         return (qp.per_channel_left_shifts == nullptr);
120     } else {
121         return (qp.per_layer_left_shift == 0);
122     }
123 }
124 
125 // Check that the provided Requantize32 is compatible with the "symmetric" hybrid kernels.  These don't include row
126 // sums, so the 'b_offset' has to be zero.
quant_hybrid_symmetric(const Requantize32 & qp)127 inline bool quant_hybrid_symmetric(const Requantize32 &qp) {
128     return quant_no_left_shift(qp) && qp.b_offset == 0;
129 }
130 
131 // Check that the provided Requantize32 is compatible with the "asymmetric" hybrid kernels.  These don't support per
132 // channel quantization.  Technically b_offset==0 cases would work, but it is a waste to sum and then multiply by 0...
quant_hybrid_asymmetric(const Requantize32 & qp)133 inline bool quant_hybrid_asymmetric(const Requantize32 &qp) {
134     return quant_no_left_shift(qp) /*  && qp.b_offset != 0 */ && qp.per_channel_requant==false;
135 }
136 
// Input descriptor that is either "direct" (a base pointer plus a stride) or "indirect" (a triple-pointer table plus
// a starting row and column).  'is_indirect' records which of the two members each constructor filled in; the other
// member stays zero-initialized.
template<typename T>
struct IndirectInputArg {
    // Used when is_indirect == false.
    struct {
        const T *base;
        size_t   stride;
    } direct = {};
    // Used when is_indirect == true.
    struct {
        const T * const * const * ptr;
        unsigned int start_row;
        unsigned int start_col;
    } indirect = {};
    bool is_indirect = false;

    // Direct
    IndirectInputArg(const T *base, size_t stride) : is_indirect(false) {
        direct = { base, stride };
    }

    // Indirect
    IndirectInputArg(const T * const * const *ptr, unsigned int start_row, unsigned int start_col) : is_indirect(true) {
        indirect = { ptr, start_row, start_col };
    }

    // Default: a direct descriptor with a null base and zero stride.
    IndirectInputArg() : is_indirect(false) {
        direct = { nullptr, 0 };
    }
};
168 
169 namespace utils {
170 
// get_vector_length(): Returns SVE vector length for type "T".
//
// It is required that this can be compiled by a compiler in non-SVE mode, but it must be prevented from running (at
// runtime) if SVE is not enabled.  Typically this is used by switchyard/driver code which is built in normal mode
// which then calls SVE kernels (compiled accordingly) iff SVE is detected at runtime.
//
// On AArch64 the SVE CNTB instruction is emitted as a raw encoding (.inst), so the assembler does not need SVE to be
// enabled.  CNTB yields the vector length in bytes, and dividing by sizeof(T) gives the number of T lanes.  On other
// targets a fixed 16-byte vector is assumed.
template <typename T>
inline unsigned long get_vector_length() {
#if defined(__aarch64__)
    std::uint64_t vl;

    __asm __volatile (
        ".inst 0x0420e3e0\n" // CNTB X0, ALL, MUL #1
        "mov %0, X0\n"
        : "=r" (vl)
        :
        : "x0"
    );

    return vl / sizeof(T);
#else // !defined(__aarch64__)
    return 16 / sizeof(T);
#endif // defined(__aarch64__)
}
194 
#ifdef ARM_COMPUTE_ENABLE_SME
namespace sme {

// Defined out of line (misc-sve.cpp).  Presumably returns the SME vector length in bytes - TODO confirm against
// that definition.
extern unsigned int raw_vector_length();

// get_vector_length(): Returns raw_vector_length() divided by sizeof(T), i.e. the SME vector length counted in
// elements of type "T".
template <typename T>
inline unsigned long get_vector_length() {
    const unsigned long raw_length = raw_vector_length();
    return raw_length / sizeof(T);
}

} // namespace sme
#endif // ARM_COMPUTE_ENABLE_SME
208 
209 // get_vector_length(VLType): Returns vector length for type "T".
210 //
211 // This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
212 
213 template <typename T>
get_vector_length(VLType vl_type)214 inline unsigned long get_vector_length(VLType vl_type) {
215   switch (vl_type) {
216 #ifdef ARM_COMPUTE_ENABLE_SME
217     case VLType::SME:
218       return sme::get_vector_length<T>();
219 #endif // ARM_COMPUTE_ENABLE_SME
220     case VLType::SVE:
221       return get_vector_length<T>();
222     default:
223       return 16 / sizeof(T);
224   }
225 }
226 
// get_default_activation_values(): Returns the default activation (min, max) pair for type T - the full
// representable range of T.
//
// std::numeric_limits<T>::lowest() is used for the minimum rather than min().  The two are identical for integer
// types, but for floating-point types min() is the smallest *positive* normal value.  That would give any floating
// type without its own specialization (e.g. double) a positive lower bound; lowest() gives -max() instead.
template <typename T>
inline std::tuple<T, T> get_default_activation_values()
{
    const T min = std::numeric_limits<T>::lowest();
    const T max = std::numeric_limits<T>::max();

    return std::make_tuple(min, max);
}
236 
237 // get_default_activation_values(): Returns the default values for activation min and max for float activation.
238 template <>
get_default_activation_values()239 inline std::tuple<float, float> get_default_activation_values()
240 {
241     const float min = static_cast<float>(-std::numeric_limits<float>::infinity());
242     const float max = static_cast<float>(std::numeric_limits<float>::infinity());
243 
244     return std::make_tuple(min, max);
245 }
246 
#if defined(__ARM_FP16_ARGS)
// get_default_activation_values<__fp16>(): __fp16 activations default to an unbounded range, (-infinity, +infinity).
// The infinities are produced as float and then converted to __fp16.
template <>
inline std::tuple<__fp16, __fp16> get_default_activation_values<__fp16>()
{
    const float inf = std::numeric_limits<float>::infinity();
    const __fp16 lower = static_cast<__fp16>(-inf);
    const __fp16 upper = static_cast<__fp16>(inf);

    return std::make_tuple(lower, upper);
}
#endif  // defined(__ARM_FP16_ARGS)
258 } // utils namespace
259 } // arm_gemm namespace
260 
261 using namespace arm_gemm::utils;
262