/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include #include namespace vkcompute { namespace utils { // // Hashing // /** * hash_combine is taken from c10/util/hash.h, which in turn is based on * implementation from Boost */ inline size_t hash_combine(size_t seed, size_t value) { return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u)); } // // Alignment // template inline constexpr Type align_down(const Type& number, const Type& multiple) { return (number / multiple) * multiple; } template inline constexpr Type align_up(const Type& number, const Type& multiple) { return align_down(number + multiple - 1, multiple); } template inline constexpr Type align_up_4(const Type& numerator) { return (numerator + 3) & -4; } template inline constexpr Type div_up(const Type& numerator, const Type& denominator) { return (numerator + denominator - 1) / denominator; } template inline constexpr Type div_up_4(const Type& numerator) { return (numerator + 3) / 4; } // // Casting Utilities // namespace detail { /* * x cannot be less than 0 if x is unsigned */ template static inline constexpr bool is_negative( const T& /*x*/, std::true_type /*is_unsigned*/) { return false; } /* * check if x is less than 0 if x is signed */ template static inline constexpr bool is_negative( const T& x, std::false_type /*is_unsigned*/) { return x < T(0); } /* * Returns true if x < 0 */ template inline constexpr bool is_negative(const T& x) { return is_negative(x, std::is_unsigned()); } /* * Returns true if x < lowest(Limit); standard comparison */ template static inline constexpr bool less_than_lowest( const T& x, std::false_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { return x < std::numeric_limits::lowest(); } /* * Limit can contained negative values, but x cannot; return false */ template static inline constexpr bool less_than_lowest( const T& /*x*/, std::false_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { return false; } /* * Limit cannot contained negative values, but x can; check if x is negative */ template static inline constexpr bool less_than_lowest( const T& x, std::true_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { return x < T(0); } /* * Both x and Limit cannot be negative; return false */ template static inline constexpr bool less_than_lowest( const T& /*x*/, std::true_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { return false; } /* * Returns true if x is less than the lowest value of type T */ template inline constexpr bool less_than_lowest(const T& x) { return less_than_lowest( x, std::is_unsigned(), std::is_unsigned()); } // Suppress sign compare warning when compiling with GCC // as later does not account for short-circuit rule before // raising the warning, see https://godbolt.org/z/Tr3Msnz99 #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wsign-compare" #endif /* * Returns true if x is greater than the greatest value of the type Limit */ template inline constexpr bool greater_than_max(const T& x) { constexpr bool can_overflow = std::numeric_limits::digits > std::numeric_limits::digits; return can_overflow && x > std::numeric_limits::max(); } #ifdef __GNUC__ #pragma GCC diagnostic pop #endif template std::enable_if_t< std::is_integral::value && !std::is_same::value, bool> overflows(From f) { using limit = std::numeric_limits; // Casting from signed to unsigned; allow for negative numbers to wrap using // two's complement arithmetic. if (!limit::is_signed && std::numeric_limits::is_signed) { return greater_than_max(f) || (is_negative(f) && -static_cast(f) > limit::max()); } // standard case, check if f is outside the range of type To else { return less_than_lowest(f) || greater_than_max(f); } } template std::enable_if_t::value, bool> overflows(From f) { using limit = std::numeric_limits; if (limit::has_infinity && std::isinf(static_cast(f))) { return false; } return f < limit::lowest() || f > limit::max(); } template inline constexpr To safe_downcast(const From& v) { VK_CHECK_COND(!overflows(v), "Cast failed: out of range!"); return static_cast(v); } template inline constexpr bool is_signed_to_unsigned() { return std::is_signed::value && std::is_unsigned::value; } } // namespace detail template < typename To, typename From, std::enable_if_t(), bool> = true> inline constexpr To safe_downcast(const From& v) { VK_CHECK_COND(v >= From{}, "Cast failed: negative signed to unsigned!"); return detail::safe_downcast(v); } template < typename To, typename From, std::enable_if_t(), bool> = true> inline constexpr To safe_downcast(const From& v) { return detail::safe_downcast(v); } // // Vector Types // namespace detail { template struct vec final { // NOLINTNEXTLINE Type data[N]; vec() = default; // Standard constructor with initializer list vec(std::initializer_list values) { VK_CHECK_COND(values.size() == N); std::copy(values.begin(), values.end(), data); } // Conversion constructor from an _integral_ vec type. Note that this is only // defined if `OtherType` is an integral type to disallow implicit narrowing. template < typename OtherType, typename std::enable_if< !std::is_same::value && std::is_integral::value, int>::type = 0> /* implicit */ vec(const vec& other) { for (int i = 0; i < N; ++i) { data[i] = safe_downcast(other[i]); } } const Type& operator[](const uint32_t& i) const { VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!"); return data[i]; } Type& operator[](const uint32_t& i) { VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!"); return data[i]; } }; } // namespace detail template using ivec = detail::vec; using ivec2 = ivec<2u>; using ivec3 = ivec<3u>; using ivec4 = ivec<4u>; template using uvec = detail::vec; using uvec2 = uvec<2u>; using uvec3 = uvec<3u>; using uvec4 = uvec<4u>; template using vec = detail::vec; using vec2 = vec<2u>; using vec3 = vec<3u>; using vec4 = vec<4u>; // uvec3 is the type representing tensor extents. Useful for debugging. inline std::ostream& operator<<(std::ostream& os, const uvec3& v) { os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ")"; return os; } inline std::ostream& operator<<(std::ostream& os, const ivec3& v) { os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ")"; return os; } inline std::ostream& operator<<(std::ostream& os, const uvec4& v) { os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ", " << v[3u] << ")"; return os; } inline std::ostream& operator<<(std::ostream& os, const ivec4& v) { os << "(" << v[0u] << ", " << v[1u] << ", " << v[2u] << ", " << v[3u] << ")"; return os; } template inline detail::vec divup_vec( const detail::vec& a, const detail::vec& b) { detail::vec result; for (uint32_t i = 0; i < N; ++i) { result[i] = utils::div_up(a[i], b[i]); } return result; } // // std::vector Handling // /* * Utility function to perform indexing on an std::vector. Negative indexing * is allowed. For instance, passing an index of -1 will retrieve the last * element. If the requested index is out of bounds, then 1u will be returned. */ template inline T val_at(const int64_t index, const std::vector& sizes) { const int64_t ndim = static_cast(sizes.size()); if (index >= 0) { return index >= ndim ? 1 : sizes[index]; } else { return ndim + index < 0 ? 1 : sizes[ndim + index]; } } inline ivec2 make_ivec2( const std::vector& ints, bool reverse = false) { VK_CHECK_COND(ints.size() == 2); if (reverse) { return {safe_downcast(ints[1]), safe_downcast(ints[0])}; } else { return {safe_downcast(ints[0]), safe_downcast(ints[1])}; } } inline ivec3 make_ivec3( const std::vector& ints, bool reverse = false) { VK_CHECK_COND(ints.size() == 3); if (reverse) { return { safe_downcast(ints[2]), safe_downcast(ints[1]), safe_downcast(ints[0]), }; } else { return { safe_downcast(ints[0]), safe_downcast(ints[1]), safe_downcast(ints[2]), }; } } inline ivec4 make_ivec4( const std::vector& ints, bool reverse = false) { VK_CHECK_COND(ints.size() == 4); if (reverse) { return { safe_downcast(ints[3]), safe_downcast(ints[2]), safe_downcast(ints[1]), safe_downcast(ints[0]), }; } else { return { safe_downcast(ints[0]), safe_downcast(ints[1]), safe_downcast(ints[2]), safe_downcast(ints[3]), }; } } inline ivec4 make_ivec4_prepadded1(const std::vector& ints) { VK_CHECK_COND(ints.size() <= 4); ivec4 result = {1, 1, 1, 1}; size_t base = 4 - ints.size(); for (size_t i = 0; i < ints.size(); ++i) { result[i + base] = safe_downcast(ints[i]); } return result; } inline ivec3 make_ivec3(uvec3 ints) { return { safe_downcast(ints[0u]), safe_downcast(ints[1u]), safe_downcast(ints[2u])}; } inline uvec3 make_uvec3(ivec3 ints) { return { safe_downcast(ints[0u]), safe_downcast(ints[1u]), safe_downcast(ints[2u])}; } /* * Given an vector of up to 4 uint64_t representing the sizes of a tensor, * constructs a uvec4 containing those elements in reverse order. */ inline uvec4 make_whcn_uvec4(const std::vector& arr) { uint32_t w = safe_downcast(val_at(-1, arr)); uint32_t h = safe_downcast(val_at(-2, arr)); uint32_t c = safe_downcast(val_at(-3, arr)); uint32_t n = safe_downcast(val_at(-4, arr)); return {w, h, c, n}; } /* * Given an vector of up to 4 int64_t representing the sizes of a tensor, * constructs an ivec4 containing those elements in reverse order. */ inline ivec4 make_whcn_ivec4(const std::vector& arr) { int32_t w = val_at(-1, arr); int32_t h = val_at(-2, arr); int32_t c = val_at(-3, arr); int32_t n = val_at(-4, arr); return {w, h, c, n}; } /* * Wrapper around std::accumulate that accumulates values of a container of * integral types into int64_t. Taken from `multiply_integers` in * */ template < typename C, std::enable_if_t::value, int> = 0> inline int64_t multiply_integers(const C& container) { return std::accumulate( container.begin(), container.end(), static_cast(1), std::multiplies<>()); } /* * Product of integer elements referred to by iterators; accumulates into the * int64_t datatype. Taken from `multiply_integers` in */ template < typename Iter, std::enable_if_t< std::is_integral< typename std::iterator_traits::value_type>::value, int> = 0> inline int64_t multiply_integers(Iter begin, Iter end) { // std::accumulate infers return type from `init` type, so if the `init` type // is not large enough to hold the result, computation can overflow. We use // `int64_t` here to avoid this. return std::accumulate( begin, end, static_cast(1), std::multiplies<>()); } } // namespace utils } // namespace vkcompute