1 /*
2 * Copyright (c) 2017-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #pragma once
26
27 #include "src/cpu/kernels/assembly/arm_gemm.hpp"
28
29 #include <cstddef>
30 #include <limits>
31 #include <tuple>
32
// Macro for unreachable code (e.g. impossible default cases on switch).
//
// Expands to the compiler builtin __builtin_unreachable(); the "why" argument is documentation only and is discarded
// by the expansion.  NOTE(review): control flow actually reaching this point is undefined behaviour with this
// definition - enable the assert-based form below when debugging.
#define UNREACHABLE(why) __builtin_unreachable()

// Paranoid option for the above with assert
// #define UNREACHABLE(why) assert(0 && why)
38
39 namespace arm_gemm {
40
// get_type_name(): Returns a short name for type "T", taken from the compiler-generated signature string.
//
// The name is the text following the first "cls_" in __PRETTY_FUNCTION__, up to (but not including) the next ';' or
// ']'.  Returns "(unknown)" if either the "cls_" marker or a terminator cannot be found, and "(unsupported)" when
// __GNUC__ is not defined.
template<typename T>
std::string get_type_name() {
#ifdef __GNUC__
    const std::string signature = __PRETTY_FUNCTION__;
    const std::string marker    = "cls_";

    const auto marker_pos = signature.find(marker);
    if (marker_pos == std::string::npos) {
        return "(unknown)";
    }

    // The name starts immediately after the marker and runs up to the first terminator character.
    const auto name_begin = marker_pos + marker.size();
    const auto name_end   = signature.find_first_of(";]", name_begin);
    if (name_end == std::string::npos) {
        return "(unknown)";
    }

    return signature.substr(name_begin, name_end - name_begin);
#else
    return "(unsupported)";
#endif
}
63
// iceildiv(): Integer division of "a" by "b", rounded up to the next whole number.
//
// The result is built from the quotient and remainder instead of the usual (a + b - 1) / b, so there is no
// intermediate sum that could wrap (unsigned T) or overflow (signed T) when "a" is close to the top of T's range.
// For a >= 0 and b > 0 the result equals (a + b - 1) / b whenever that sum is representable.
//
// Intended for integer T and b > 0; "b" must not be zero.
template<typename T>
inline T iceildiv(const T a, const T b) {
    const T quotient  = a / b;
    const T remainder = a % b;

    // Truncating division already rounds negative quotients upwards, so only a positive remainder needs the bump.
    return (remainder > 0) ? quotient + 1 : quotient;
}
68
// roundup(): Returns "a" rounded up to a multiple of "b".
//
// When a % b is zero "a" is returned unchanged, otherwise the result is a + b - (a % b).
template <typename T>
inline T roundup(const T a, const T b) {
    const T excess = a % b;

    return excess ? (a + b - excess) : a;
}
79
// Selects which vector length utils::get_vector_length(VLType) reports.
enum class VLType {
    None, // No runtime query: a fixed 16-byte vector is assumed.
    SVE,  // Use the SVE vector length.
    SME   // Use the SME vector length (only handled when ARM_COMPUTE_ENABLE_SME is defined).
};
85
// Describes an output location in one of two forms, selected by "is_indirect":
//  - direct:   a base pointer plus a stride.
//  - indirect: a table of pointers plus an offset.
// Whichever form is not selected by the constructor is left zero-initialized.
template<typename T>
struct IndirectOutputArg {
    struct {
        T      *base;
        size_t  stride;
    } direct = {};
    struct {
        T * const *ptr;
        size_t     offset;
    } indirect = {};
    bool is_indirect;

    // Default: direct form with a null base and zero stride.
    IndirectOutputArg() : is_indirect(false) {
    }

    // Direct
    IndirectOutputArg(T *base, size_t stride) : direct{base, stride}, is_indirect(false) {
    }

    // Indirect
    IndirectOutputArg(T * const * ptr, size_t offset) : indirect{ptr, offset}, is_indirect(true) {
    }
};
115
116 // Check that the provided Requantize32 doesn't have a left shift.
quant_no_left_shift(const Requantize32 & qp)117 inline bool quant_no_left_shift(const Requantize32 &qp) {
118 if (qp.per_channel_requant) {
119 return (qp.per_channel_left_shifts == nullptr);
120 } else {
121 return (qp.per_layer_left_shift == 0);
122 }
123 }
124
125 // Check that the provided Requantize32 is compatible with the "symmetric" hybrid kernels. These don't include row
126 // sums, so the 'b_offset' has to be zero.
quant_hybrid_symmetric(const Requantize32 & qp)127 inline bool quant_hybrid_symmetric(const Requantize32 &qp) {
128 return quant_no_left_shift(qp) && qp.b_offset == 0;
129 }
130
131 // Check that the provided Requantize32 is compatible with the "asymmetric" hybrid kernels. These don't support per
132 // channel quantization. Technically b_offset==0 cases would work, but it is a waste to sum and then multiply by 0...
quant_hybrid_asymmetric(const Requantize32 & qp)133 inline bool quant_hybrid_asymmetric(const Requantize32 &qp) {
134 return quant_no_left_shift(qp) /* && qp.b_offset != 0 */ && qp.per_channel_requant==false;
135 }
136
// Describes an input location in one of two forms, selected by "is_indirect":
//  - direct:   a base pointer plus a stride.
//  - indirect: a three-level pointer table plus a starting row and column.
// Whichever form is not selected by the constructor is left zero-initialized.
template<typename T>
struct IndirectInputArg {
    struct {
        const T *base;
        size_t   stride;
    } direct = {};
    struct {
        const T * const * const * ptr;
        unsigned int              start_row;
        unsigned int              start_col;
    } indirect = {};
    bool is_indirect;

    // Default: direct form with a null base and zero stride.
    IndirectInputArg() : is_indirect(false) {
    }

    // Direct
    IndirectInputArg(const T *base, size_t stride) : direct{base, stride}, is_indirect(false) {
    }

    // Indirect
    IndirectInputArg(const T * const * const *ptr, unsigned int start_row, unsigned int start_col)
        : indirect{ptr, start_row, start_col}, is_indirect(true) {
    }
};
168
169 namespace utils {
170
// get_vector_length(): Returns SVE vector length for type "T", i.e. the number of "T" elements per vector.
//
// It is required that this can be compiled by a compiler in non-SVE mode, but it must be prevented from running (at
// runtime) if SVE is not enabled.  Typically this is used by switchyard/driver code which is built in normal mode
// which then calls SVE kernels (compiled accordingly) iff SVE is detected at runtime.
//
// On targets other than AArch64 no instruction is issued and a 16-byte vector is assumed.
template <typename T>
inline unsigned long get_vector_length() {
#if !defined(__aarch64__)
    return 16 / sizeof(T);
#else // defined(__aarch64__)
    uint64_t vl_bytes;

    // The instruction is given as a raw encoding and its result is read back via X0, which is therefore listed as
    // clobbered.
    __asm __volatile (
        ".inst 0x0420e3e0\n" // CNTB X0, ALL, MUL #1
        "mov %0, X0\n"
        : "=r" (vl_bytes)
        :
        : "x0"
    );

    return vl_bytes / sizeof(T);
#endif // !defined(__aarch64__)
}
194
#ifdef ARM_COMPUTE_ENABLE_SME
namespace sme {

// function from misc-sve.cpp
extern unsigned int raw_vector_length();

// get_vector_length(): Returns raw_vector_length() divided by the size of "T".
// NOTE(review): presumably raw_vector_length() reports the SME vector length in bytes - confirm in misc-sve.cpp.
template <typename T>
inline unsigned long get_vector_length() {
    const unsigned int raw_vl = raw_vector_length();

    return raw_vl / sizeof(T);
}

} // namespace sme
#endif // ARM_COMPUTE_ENABLE_SME
208
209 // get_vector_length(VLType): Returns vector length for type "T".
210 //
211 // This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
212
213 template <typename T>
get_vector_length(VLType vl_type)214 inline unsigned long get_vector_length(VLType vl_type) {
215 switch (vl_type) {
216 #ifdef ARM_COMPUTE_ENABLE_SME
217 case VLType::SME:
218 return sme::get_vector_length<T>();
219 #endif // ARM_COMPUTE_ENABLE_SME
220 case VLType::SVE:
221 return get_vector_length<T>();
222 default:
223 return 16 / sizeof(T);
224 }
225 }
226
// get_default_activation_values(): Returns the default values for activation min and max for integer activation.
//
// The returned tuple is (min, max) and spans the whole range of "T".  The lower bound uses numeric_limits::lowest()
// rather than min(): the two are identical for integer types, but for a floating point "T" min() is the smallest
// positive value, which would give a positive lower bound for any floating point type that reaches this generic
// version instead of a dedicated specialization.
template <typename T>
inline std::tuple<T, T> get_default_activation_values()
{
    const T min = std::numeric_limits<T>::lowest();
    const T max = std::numeric_limits<T>::max();

    return std::make_tuple(min, max);
}
236
237 // get_default_activation_values(): Returns the default values for activation min and max for float activation.
238 template <>
get_default_activation_values()239 inline std::tuple<float, float> get_default_activation_values()
240 {
241 const float min = static_cast<float>(-std::numeric_limits<float>::infinity());
242 const float max = static_cast<float>(std::numeric_limits<float>::infinity());
243
244 return std::make_tuple(min, max);
245 }
246
#if defined(__ARM_FP16_ARGS)
// get_default_activation_values(): Returns the default values for activation min and max for __fp16 activation.
//
// The returned tuple is (-infinity, +infinity), obtained by converting the float infinities to __fp16.
template <>
inline std::tuple<__fp16, __fp16> get_default_activation_values()
{
    const float inf = std::numeric_limits<float>::infinity();

    const __fp16 lower = static_cast<__fp16>(-inf);
    const __fp16 upper = static_cast<__fp16>(inf);

    return std::make_tuple(lower, upper);
}
#endif // defined(__ARM_FP16_ARGS)
258 } // utils namespace
259 } // arm_gemm namespace
260
// Makes every name in arm_gemm::utils visible without qualification in any file that includes this header.
// NOTE(review): a using-directive at header scope affects all includers; presumably existing callers rely on the
// unqualified names, so confirm before removing it.
using namespace arm_gemm::utils;
262