xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/utils/VecUtils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h>
12 
13 #include <executorch/backends/vulkan/runtime/vk_api/Exception.h>
14 
15 #include <cmath>
16 #include <limits>
17 #include <numeric>
18 #include <type_traits>
19 
20 namespace vkcompute {
21 namespace utils {
22 
23 //
24 // Hashing
25 //
26 
/**
 * Folds `value` into the running hash `seed` and returns the new hash.
 * The mixing formula comes from c10/util/hash.h, which in turn follows the
 * Boost implementation.
 */
inline size_t hash_combine(size_t seed, size_t value) {
  // Golden-ratio style constant plus two shifted copies of the seed.
  const size_t mixed = value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
  return seed ^ mixed;
}
34 
35 //
36 // Alignment
37 //
38 
/*
 * Rounds `number` down to the nearest multiple of `multiple` by dividing and
 * multiplying back.
 */
template <typename Type>
inline constexpr Type align_down(const Type& number, const Type& multiple) {
  const auto whole_multiples = number / multiple;
  return whole_multiples * multiple;
}
43 
/*
 * Rounds `number` up to the nearest multiple of `multiple`: bias the value by
 * `multiple - 1`, then round the biased value down.
 */
template <typename Type>
inline constexpr Type align_up(const Type& number, const Type& multiple) {
  const auto biased = number + multiple - 1;
  return (biased / multiple) * multiple;
}
48 
/*
 * Rounds `numerator` up to the nearest multiple of 4 by adding 3 and clearing
 * the two lowest bits.
 */
template <typename Type>
inline constexpr Type align_up_4(const Type& numerator) {
  const auto bumped = numerator + 3;
  return bumped & ~3;
}
53 
/*
 * Division that rounds the quotient up: the numerator is biased by
 * `denominator - 1` before dividing.
 */
template <typename Type>
inline constexpr Type div_up(const Type& numerator, const Type& denominator) {
  const auto biased = numerator + denominator - 1;
  return biased / denominator;
}
58 
/*
 * Divides `numerator` by 4, rounding the quotient up.
 */
template <typename Type>
inline constexpr Type div_up_4(const Type& numerator) {
  const auto biased = numerator + 3;
  return biased / 4;
}
63 
64 //
65 // Casting Utilities
66 //
67 
namespace detail {

/*
 * Tag-dispatch target for unsigned T: an unsigned value is never below zero,
 * so the argument is ignored.
 */
template <typename T>
static inline constexpr bool is_negative(
    const T& /*x*/,
    std::true_type /*is_unsigned*/) {
  return false;
}

/*
 * Tag-dispatch target for every other T: compare x against zero.
 */
template <typename T>
static inline constexpr bool is_negative(
    const T& x,
    std::false_type /*is_unsigned*/) {
  return x < T(0);
}

/*
 * Returns true if x < 0. Dispatches on std::is_unsigned<T> so that no
 * comparison against zero is emitted for unsigned types.
 */
template <typename T>
inline constexpr bool is_negative(const T& x) {
  return is_negative(x, std::is_unsigned<T>());
}

/*
 * Neither Limit nor T is unsigned: plain comparison of x against
 * lowest(Limit).
 */
template <typename Limit, typename T>
static inline constexpr bool less_than_lowest(
    const T& x,
    std::false_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  return x < std::numeric_limits<Limit>::lowest();
}

/*
 * Limit can contain negative values, but x cannot; return false.
 */
template <typename Limit, typename T>
static inline constexpr bool less_than_lowest(
    const T& /*x*/,
    std::false_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}

/*
 * Limit cannot contain negative values, but x can; check if x is negative.
 */
template <typename Limit, typename T>
static inline constexpr bool less_than_lowest(
    const T& x,
    std::true_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  return x < T(0);
}

/*
 * Neither x nor Limit can be negative; return false.
 */
template <typename Limit, typename T>
static inline constexpr bool less_than_lowest(
    const T& /*x*/,
    std::true_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}

/*
 * Returns true if x is less than the lowest value of type Limit. Dispatches
 * on the unsignedness of Limit and of T to pick a comparison that never mixes
 * signed and unsigned operands.
 */
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T& x) {
  return less_than_lowest<Limit>(
      x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
}

// Suppress the sign-compare warning when compiling with GCC, as the latter
// does not account for the short-circuit rule before raising the warning,
// see https://godbolt.org/z/Tr3Msnz99
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#endif

/*
 * Returns true if x is greater than the greatest value of the type Limit.
 * The comparison is only evaluated when T has more value digits than Limit;
 * otherwise x cannot exceed max(Limit) and the result is false.
 */
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T& x) {
  constexpr bool can_overflow =
      std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
  return can_overflow && x > std::numeric_limits<Limit>::max();
}

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

/*
 * Integral source (bool excluded): returns true if f cannot be converted to
 * To without leaving To's range.
 */
template <typename To, typename From>
std::enable_if_t<
    std::is_integral<From>::value && !std::is_same<From, bool>::value,
    bool>
overflows(From f) {
  using limit = std::numeric_limits<To>;
  // Casting from signed to unsigned; allow for negative numbers to wrap using
  // two's complement arithmetic.
  if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
    return greater_than_max<To>(f) ||
        (is_negative(f) && -static_cast<uint64_t>(f) > limit::max());
  }
  // standard case, check if f is outside the range of type To
  else {
    return less_than_lowest<To>(f) || greater_than_max<To>(f);
  }
}

/*
 * Floating-point source: returns true if f cannot be represented in To.
 *  - NaN overflows unless To itself has a quiet NaN; converting NaN to a type
 *    without one (e.g. any integral type) is undefined behavior, and the
 *    range comparisons below are always false for NaN.
 *  - Infinities are accepted when To has an infinity.
 *  - Otherwise f must lie within [lowest(To), max(To)].
 */
template <typename To, typename From>
std::enable_if_t<std::is_floating_point<From>::value, bool> overflows(From f) {
  using limit = std::numeric_limits<To>;
  const double as_double = static_cast<double>(f);
  if (std::isnan(as_double)) {
    return !limit::has_quiet_NaN;
  }
  if (limit::has_infinity && std::isinf(as_double)) {
    return false;
  }
  return f < limit::lowest() || f > limit::max();
}

/*
 * Converts v to To after checking, via VK_CHECK_COND, that overflows<To>(v)
 * is false.
 */
template <typename To, typename From>
inline constexpr To safe_downcast(const From& v) {
  VK_CHECK_COND(!overflows<To>(v), "Cast failed: out of range!");
  return static_cast<To>(v);
}

/*
 * Returns true when From is a signed type and To is an unsigned type.
 */
template <typename To, typename From>
inline constexpr bool is_signed_to_unsigned() {
  return std::is_signed<From>::value && std::is_unsigned<To>::value;
}

} // namespace detail
212 
213 template <
214     typename To,
215     typename From,
216     std::enable_if_t<detail::is_signed_to_unsigned<To, From>(), bool> = true>
safe_downcast(const From & v)217 inline constexpr To safe_downcast(const From& v) {
218   VK_CHECK_COND(v >= From{}, "Cast failed: negative signed to unsigned!");
219   return detail::safe_downcast<To, From>(v);
220 }
221 
222 template <
223     typename To,
224     typename From,
225     std::enable_if_t<!detail::is_signed_to_unsigned<To, From>(), bool> = true>
safe_downcast(const From & v)226 inline constexpr To safe_downcast(const From& v) {
227   return detail::safe_downcast<To, From>(v);
228 }
229 
230 //
231 // Vector Types
232 //
233 
234 namespace detail {
235 
236 template <typename Type, uint32_t N>
237 struct vec final {
238   // NOLINTNEXTLINE
239   Type data[N];
240 
241   vec() = default;
242 
243   // Standard constructor with initializer list
vecfinal244   vec(std::initializer_list<Type> values) {
245     VK_CHECK_COND(values.size() == N);
246     std::copy(values.begin(), values.end(), data);
247   }
248 
249   // Conversion constructor from an _integral_ vec type. Note that this is only
250   // defined if `OtherType` is an integral type to disallow implicit narrowing.
251   template <
252       typename OtherType,
253       typename std::enable_if<
254           !std::is_same<Type, OtherType>::value &&
255               std::is_integral<OtherType>::value,
256           int>::type = 0>
vecfinal257   /* implicit */ vec(const vec<OtherType, N>& other) {
258     for (int i = 0; i < N; ++i) {
259       data[i] = safe_downcast<Type>(other[i]);
260     }
261   }
262 
263   const Type& operator[](const uint32_t& i) const {
264     VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!");
265     return data[i];
266   }
267 
268   Type& operator[](const uint32_t& i) {
269     VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!");
270     return data[i];
271   }
272 };
273 
274 } // namespace detail
275 
276 template <uint32_t N>
277 using ivec = detail::vec<int32_t, N>;
278 using ivec2 = ivec<2u>;
279 using ivec3 = ivec<3u>;
280 using ivec4 = ivec<4u>;
281 
282 template <uint32_t N>
283 using uvec = detail::vec<uint32_t, N>;
284 using uvec2 = uvec<2u>;
285 using uvec3 = uvec<3u>;
286 using uvec4 = uvec<4u>;
287 
288 template <uint32_t N>
289 using vec = detail::vec<float, N>;
290 using vec2 = vec<2u>;
291 using vec3 = vec<3u>;
292 using vec4 = vec<4u>;
293 
294 // uvec3 is the type representing tensor extents. Useful for debugging.
295 inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
296   os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ")";
297   return os;
298 }
299 
300 inline std::ostream& operator<<(std::ostream& os, const ivec3& v) {
301   os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ")";
302   return os;
303 }
304 
305 inline std::ostream& operator<<(std::ostream& os, const uvec4& v) {
306   os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ", " << v[3u] << ")";
307   return os;
308 }
309 
310 inline std::ostream& operator<<(std::ostream& os, const ivec4& v) {
311   os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ", " << v[3u] << ")";
312   return os;
313 }
314 
315 template <typename T, uint32_t N>
divup_vec(const detail::vec<T,N> & a,const detail::vec<T,N> & b)316 inline detail::vec<T, N> divup_vec(
317     const detail::vec<T, N>& a,
318     const detail::vec<T, N>& b) {
319   detail::vec<T, N> result;
320   for (uint32_t i = 0; i < N; ++i) {
321     result[i] = utils::div_up(a[i], b[i]);
322   }
323   return result;
324 }
325 
326 //
327 // std::vector<T> Handling
328 //
329 
/*
 * Indexes into an std::vector<T> with support for negative indices, which
 * count back from the end (-1 is the last element). An index that falls
 * outside the vector yields 1 instead of an element.
 */
template <typename T>
inline T val_at(const int64_t index, const std::vector<T>& sizes) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  // Translate a negative index into its position from the front.
  const int64_t pos = index < 0 ? ndim + index : index;
  const bool in_range = pos >= 0 && pos < ndim;
  return in_range ? sizes[pos] : 1;
}
344 
345 inline ivec2 make_ivec2(
346     const std::vector<int64_t>& ints,
347     bool reverse = false) {
348   VK_CHECK_COND(ints.size() == 2);
349   if (reverse) {
350     return {safe_downcast<int32_t>(ints[1]), safe_downcast<int32_t>(ints[0])};
351   } else {
352     return {safe_downcast<int32_t>(ints[0]), safe_downcast<int32_t>(ints[1])};
353   }
354 }
355 
356 inline ivec3 make_ivec3(
357     const std::vector<int64_t>& ints,
358     bool reverse = false) {
359   VK_CHECK_COND(ints.size() == 3);
360   if (reverse) {
361     return {
362         safe_downcast<int32_t>(ints[2]),
363         safe_downcast<int32_t>(ints[1]),
364         safe_downcast<int32_t>(ints[0]),
365     };
366   } else {
367     return {
368         safe_downcast<int32_t>(ints[0]),
369         safe_downcast<int32_t>(ints[1]),
370         safe_downcast<int32_t>(ints[2]),
371     };
372   }
373 }
374 
375 inline ivec4 make_ivec4(
376     const std::vector<int64_t>& ints,
377     bool reverse = false) {
378   VK_CHECK_COND(ints.size() == 4);
379   if (reverse) {
380     return {
381         safe_downcast<int32_t>(ints[3]),
382         safe_downcast<int32_t>(ints[2]),
383         safe_downcast<int32_t>(ints[1]),
384         safe_downcast<int32_t>(ints[0]),
385     };
386   } else {
387     return {
388         safe_downcast<int32_t>(ints[0]),
389         safe_downcast<int32_t>(ints[1]),
390         safe_downcast<int32_t>(ints[2]),
391         safe_downcast<int32_t>(ints[3]),
392     };
393   }
394 }
395 
make_ivec4_prepadded1(const std::vector<int64_t> & ints)396 inline ivec4 make_ivec4_prepadded1(const std::vector<int64_t>& ints) {
397   VK_CHECK_COND(ints.size() <= 4);
398 
399   ivec4 result = {1, 1, 1, 1};
400   size_t base = 4 - ints.size();
401   for (size_t i = 0; i < ints.size(); ++i) {
402     result[i + base] = safe_downcast<int32_t>(ints[i]);
403   }
404 
405   return result;
406 }
407 
make_ivec3(uvec3 ints)408 inline ivec3 make_ivec3(uvec3 ints) {
409   return {
410       safe_downcast<int32_t>(ints[0u]),
411       safe_downcast<int32_t>(ints[1u]),
412       safe_downcast<int32_t>(ints[2u])};
413 }
414 
make_uvec3(ivec3 ints)415 inline uvec3 make_uvec3(ivec3 ints) {
416   return {
417       safe_downcast<uint32_t>(ints[0u]),
418       safe_downcast<uint32_t>(ints[1u]),
419       safe_downcast<uint32_t>(ints[2u])};
420 }
421 
422 /*
423  * Given an vector of up to 4 uint64_t representing the sizes of a tensor,
424  * constructs a uvec4 containing those elements in reverse order.
425  */
make_whcn_uvec4(const std::vector<int64_t> & arr)426 inline uvec4 make_whcn_uvec4(const std::vector<int64_t>& arr) {
427   uint32_t w = safe_downcast<uint32_t>(val_at(-1, arr));
428   uint32_t h = safe_downcast<uint32_t>(val_at(-2, arr));
429   uint32_t c = safe_downcast<uint32_t>(val_at(-3, arr));
430   uint32_t n = safe_downcast<uint32_t>(val_at(-4, arr));
431 
432   return {w, h, c, n};
433 }
434 
435 /*
436  * Given an vector of up to 4 int64_t representing the sizes of a tensor,
437  * constructs an ivec4 containing those elements in reverse order.
438  */
make_whcn_ivec4(const std::vector<int64_t> & arr)439 inline ivec4 make_whcn_ivec4(const std::vector<int64_t>& arr) {
440   int32_t w = val_at(-1, arr);
441   int32_t h = val_at(-2, arr);
442   int32_t c = val_at(-3, arr);
443   int32_t n = val_at(-4, arr);
444 
445   return {w, h, c, n};
446 }
447 
/*
 * Returns the product of all elements of a container of integral values,
 * accumulated in int64_t starting from 1 (so an empty container yields 1).
 * Taken from `multiply_integers` in <c10/util/accumulate.h>
 */
template <
    typename C,
    std::enable_if_t<std::is_integral<typename C::value_type>::value, int> = 0>
inline int64_t multiply_integers(const C& container) {
  int64_t product = 1;
  for (auto it = container.begin(); it != container.end(); ++it) {
    product = product * *it;
  }
  return product;
}
463 
/*
 * Returns the product of the integral elements in [begin, end). Taken from
 * `multiply_integers` in <c10/util/accumulate.h>
 */
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int> = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // The running product is kept in `int64_t` rather than the element type so
  // that a product too large for the element type does not overflow.
  int64_t product = 1;
  for (Iter it = begin; it != end; ++it) {
    product = product * *it;
  }
  return product;
}
481 
482 } // namespace utils
483 } // namespace vkcompute
484