#pragma once

#include <cuda.h>
#include <limits.h>
#include <math.h>
#include <float.h>
// <stdint.h> declares the fixed-width integer types (int8_t ... int64_t,
// uint8_t) and the INTn_MIN / INTn_MAX / UINT8_MAX macros used below. Include
// it explicitly rather than relying on a transitive include, which is not
// guaranteed on every platform/compiler.
#include <stdint.h>

// NumericLimits.cuh holds numeric-limit definitions for commonly used types.
// This header is specific to ROCm/HIP and may be removed in the future.
// It is derived from the legacy THCNumerics.cuh.
//
// For integral types, lower_bound() and upper_bound() return the same values
// as lowest() and max(). For floating-point types they return -inf and +inf.
// They are useful when implementing min, max, and similar operations.
//
// NOTE(review): at::Half and at::BFloat16 are used below, but none of the
// headers included here declares them. This file presumably relies on the
// includer to have made them visible first — confirm.

namespace at {

// Primary template, intentionally empty: only the explicit specializations
// below provide lowest(), max(), lower_bound() and upper_bound().
template <typename T>
struct numeric_limits {
};

// WARNING: the at::numeric_limits specializations below exist only to support
// HIP compilation for the moment. Use std::numeric_limits if you are not
// compiling for ROCm.
// From @colesbury: "The functions on numeric_limits aren't marked with
// __device__ which is why they don't work with ROCm. CUDA allows them
// because they're constexpr."

namespace {
// ROCm doesn't like INFINITY either, so the value is captured once in a
// constexpr double. The float and double specializations use it for
// lower_bound() / upper_bound().
constexpr double inf = INFINITY;
}  // namespace

template <>
struct numeric_limits<bool> {
  static inline __host__ __device__ bool lowest() { return false; }
  static inline __host__ __device__ bool max() { return true; }
  static inline __host__ __device__ bool lower_bound() { return false; }
  static inline __host__ __device__ bool upper_bound() { return true; }
};

template <>
struct numeric_limits<uint8_t> {
  static inline __host__ __device__ uint8_t lowest() { return 0; }
  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
  static inline __host__ __device__ uint8_t lower_bound() { return 0; }
  static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
};

template <>
struct numeric_limits<int8_t> {
  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
  static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
  static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
};

template <>
struct numeric_limits<int16_t> {
  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
  static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
  static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
};

template <>
struct numeric_limits<int32_t> {
  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
  static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
  static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
};

template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
  // MSVC builds use the _I64_MIN / _I64_MAX spellings.
  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
  static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
  static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
#else
  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
  static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
  static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
#endif
};

// Values are built from raw bit patterns via from_bits() (IEEE binary16
// layout): 0xFBFF / 0x7BFF are the most negative / most positive finite
// values, and 0xFC00 / 0x7C00 are -inf / +inf.
template <>
struct numeric_limits<at::Half> {
  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
  static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
  static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
};

// Values are built from raw bit patterns via from_bits() (bfloat16 layout):
// 0xFF7F / 0x7F7F are the most negative / most positive finite values, and
// 0xFF80 / 0x7F80 are -inf / +inf.
template <>
struct numeric_limits<at::BFloat16> {
  static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
  static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
};

template <>
struct numeric_limits<float> {
  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
  static inline __host__ __device__ float max() { return FLT_MAX; }
  static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
  static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
};

template <>
struct numeric_limits<double> {
  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
  static inline __host__ __device__ double max() { return DBL_MAX; }
  static inline __host__ __device__ double lower_bound() { return -inf; }
  static inline __host__ __device__ double upper_bound() { return inf; }
};

} // namespace at