xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/NumericLimits.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cuda.h>
4 #include <limits.h>
5 #include <math.h>
6 #include <float.h>
7 
// NumericLimits.cuh holds numeric-limit definitions for commonly used types.
// This header exists specifically for ROCm/HIP compilation and may be removed
// in the future. It is derived from the legacy THCNumerics.cuh.

// The lower_bound and upper_bound constants are the same as lowest and max for
// integral types, but are -inf and +inf for floating-point types. They are
// useful in implementing min, max, etc.
15 
namespace at {

// Primary template. It is intentionally empty: only the explicit
// specializations below define lowest(), max(), lower_bound() and
// upper_bound(), so naming one of those members for any other T is a
// compile-time error rather than a silently wrong value.
template <typename T>
struct numeric_limits {
};

// WARNING: the following at::numeric_limits definitions exist only to support
//          HIP compilation for the moment. Use std::numeric_limits if you are
//          not compiling for ROCm.
//          From @colesbury: "The functions on numeric_limits aren't marked with
//          __device__ which is why they don't work with ROCm. CUDA allows them
//          because they're constexpr."
//
// Every accessor below is constexpr in addition to __host__ __device__. Each
// one only returns a compile-time constant, so, like std::numeric_limits, the
// results can also be used in constant expressions (static_assert conditions,
// template arguments, constexpr initializers). Ordinary runtime calls behave
// exactly as before.
//
// lowest()/max() are the most negative / largest finite values of the type.
// lower_bound()/upper_bound() equal lowest()/max() for integral types and are
// -inf/+inf for floating-point types.

namespace {
  // ROCm also has trouble with the INFINITY macro used directly, so it is
  // captured once in this named constant and reused by the floating-point
  // specializations below.
  constexpr double inf = INFINITY;
}

template <>
struct numeric_limits<bool> {
  static constexpr __host__ __device__ bool lowest() { return false; }
  static constexpr __host__ __device__ bool max() { return true; }
  static constexpr __host__ __device__ bool lower_bound() { return false; }
  static constexpr __host__ __device__ bool upper_bound() { return true; }
};

template <>
struct numeric_limits<uint8_t> {
  static constexpr __host__ __device__ uint8_t lowest() { return 0; }
  static constexpr __host__ __device__ uint8_t max() { return UINT8_MAX; }
  static constexpr __host__ __device__ uint8_t lower_bound() { return 0; }
  static constexpr __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
};

template <>
struct numeric_limits<int8_t> {
  static constexpr __host__ __device__ int8_t lowest() { return INT8_MIN; }
  static constexpr __host__ __device__ int8_t max() { return INT8_MAX; }
  static constexpr __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
  static constexpr __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
};

template <>
struct numeric_limits<int16_t> {
  static constexpr __host__ __device__ int16_t lowest() { return INT16_MIN; }
  static constexpr __host__ __device__ int16_t max() { return INT16_MAX; }
  static constexpr __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
  static constexpr __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
};

template <>
struct numeric_limits<int32_t> {
  static constexpr __host__ __device__ int32_t lowest() { return INT32_MIN; }
  static constexpr __host__ __device__ int32_t max() { return INT32_MAX; }
  static constexpr __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
  static constexpr __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
};

template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
  // MSVC spelling of the 64-bit integer limits.
  static constexpr __host__ __device__ int64_t lowest() { return _I64_MIN; }
  static constexpr __host__ __device__ int64_t max() { return _I64_MAX; }
  static constexpr __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
  static constexpr __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
#else
  static constexpr __host__ __device__ int64_t lowest() { return INT64_MIN; }
  static constexpr __host__ __device__ int64_t max() { return INT64_MAX; }
  static constexpr __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
  static constexpr __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
#endif
};

// IEEE 754 binary16 bit patterns:
//   0xFBFF = -65504 (most negative finite), 0x7BFF = 65504 (largest finite),
//   0xFC00 = -inf, 0x7C00 = +inf.
template <>
struct numeric_limits<at::Half> {
  static constexpr __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
  static constexpr __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
  static constexpr __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
  static constexpr __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
};

// bfloat16 bit patterns (upper 16 bits of the matching float32):
//   0xFF7F = most negative finite, 0x7F7F = largest finite,
//   0xFF80 = -inf, 0x7F80 = +inf.
template <>
struct numeric_limits<at::BFloat16> {
  static constexpr __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
  static constexpr __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
  static constexpr __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
  static constexpr __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
};

template <>
struct numeric_limits<float> {
  static constexpr __host__ __device__ float lowest() { return -FLT_MAX; }
  static constexpr __host__ __device__ float max() { return FLT_MAX; }
  static constexpr __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
  static constexpr __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
};

template <>
struct numeric_limits<double> {
  static constexpr __host__ __device__ double lowest() { return -DBL_MAX; }
  static constexpr __host__ __device__ double max() { return DBL_MAX; }
  static constexpr __host__ __device__ double lower_bound() { return -inf; }
  static constexpr __host__ __device__ double upper_bound() { return inf; }
};

} // namespace at
122