#pragma once

#include <cuda.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

#include <ATen/NumericUtils.h>

#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif

template <typename T>
struct AtomicFPOp;

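// AtomicFPOp implements a compare-and-swap (CAS) read-modify-write loop for
// floating-point types without native atomic support. The 16-bit
// specializations (at::Half, at::BFloat16) operate on the aligned 32-bit word
// containing the value: extract the relevant half-word, apply `func`, and
// publish the result with atomicCAS, retrying until no other thread has
// modified the word in the meantime.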
template <>
struct AtomicFPOp<at::Half> {
  template <typename func_t>
  inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::Half hsum;
    do {
      assumed = old;
      hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      hsum = func(hsum, val);
      old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    return hsum;
  }
};

template <>
struct AtomicFPOp<at::BFloat16> {
  template <typename func_t>
  inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::BFloat16 bsum;
    do {
      assumed = old;
      bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      bsum = func(bsum, val);
      old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    return bsum;
  }
};

template <>
struct AtomicFPOp<double> {
  template <typename func_t>
  inline __device__ double operator() (double * address, double val, const func_t& func) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull;
    unsigned long long int assumed;

    do {
      assumed = old;
      old = atomicCAS(address_as_ull, assumed, func(val, assumed));
      // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);

    return __longlong_as_double(old);
  }
};

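// ATOMIC_INTEGER_IMPL(NAME) generates the Atomic##NAME##IntegerImpl<T, sizeof(T)>
// functors for 1-, 2-, 4- and 8-byte integer types. The 1- and 2-byte variants
// emulate the atomic by running an atomicCAS loop on the aligned 32-bit word
// that contains the value; the 4- and 8-byte variants CAS the value directly.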
#define ATOMIC_INTEGER_IMPL(NAME) \
template <typename T, size_t n> \
struct Atomic##NAME##IntegerImpl; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 1> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    size_t offset = (size_t)address & 3; \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
    uint32_t old = *address_as_ui; \
    uint32_t shift = offset * 8; \
    uint32_t old_byte; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      old_byte = (old >> shift) & 0xff; \
      newval = static_cast<uint8_t>(func(val, static_cast<T>(old_byte))); \
      newval = (old & ~(0x000000ff << shift)) | (newval << shift); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 2> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    size_t offset = (size_t)address & 2; \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
    bool is_32_align = offset; \
    uint32_t old = *address_as_ui; \
    uint32_t old_bytes; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      old_bytes = is_32_align ? old >> 16 : old & 0xffff; \
      newval = static_cast<uint16_t>(func(val, static_cast<T>(old_bytes))); \
      newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 4> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    uint32_t * address_as_ui = (uint32_t *) (address); \
    uint32_t old = *address_as_ui; \
    uint32_t newval; \
    uint32_t assumed; \
 \
    do { \
      assumed = old; \
      newval = static_cast<uint32_t>(func(val, static_cast<T>(old))); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
}; \
 \
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 8> { \
  template <typename func_t> \
  inline __device__ void operator()(T *address, T val, const func_t& func) { \
    unsigned long long * address_as_ui = (unsigned long long *) (address); \
    unsigned long long old = *address_as_ui; \
    unsigned long long newval; \
    unsigned long long assumed; \
 \
    do { \
      assumed = old; \
      newval = static_cast<uint64_t>(func(val, static_cast<T>(old))); \
      old = atomicCAS(address_as_ui, assumed, newval); \
    } while (assumed != old); \
  } \
};


#define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
                                                      val, \
                                                      [](DTYPE a, DTYPE b) { \
                                                          return OP; \
                                                      }); \
}

ATOMIC_INTEGER_IMPL(Add)
GPU_ATOMIC_INTEGER(Add, a || b, bool)
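// For reference, the instantiation above expands (roughly) to:
//
//   inline __device__ void gpuAtomicAdd(bool *address, bool val) {
//     AtomicAddIntegerImpl<bool, sizeof(bool)>()(address, val,
//                                                [](bool a, bool b) { return a || b; });
//   }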

// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
                                                   val,
                                                   [](uint8_t a, uint8_t b) {
                                                      return a + b;
                                                   });
}

inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
                                                 val,
                                                 [](int8_t a, int8_t b) {
                                                   return a + b;
                                                 });
}

inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
                                                   val,
                                                   [](int16_t a, int16_t b) {
                                                     return a + b;
                                                   });
}

inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
  return atomicAdd(address, val);
}

inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#if defined(USE_ROCM)
  __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
  static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
  atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
#endif
}

inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half hsum, at::Half val) {
                                  return hsum + val;
                                });
#else
  return atomicAdd(reinterpret_cast<__half*>(address), val);
#endif
}

inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum + val;
                                    });
#else
  __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
  return *reinterpret_cast<c10::BFloat16*>(&r);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
// from the CUDA C Programming Guide
inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wgcc-compat"
    __attribute__((enable_if(true, "")))
#pragma GCC diagnostic pop
#endif
{

  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val + __longlong_as_double(assumed));
                              });
}
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))

/* Note [hip-clang differences to hcc]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
 * It exports the __HIP__ macro, so we can differentiate between hcc and
 * hip-clang. In the code below, hcc only gained support for atomicAdd on
 * double after work week 18312; hip-clang has supported it from the first
 * version. In general, the code-visible differences between hip-clang and hcc
 * will be minimal.
 */

#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
  // This needs to be defined for the host side pass
  inline __device__ double atomicAdd(double *address, double val) { }
#endif
#endif

inline __device__ double gpuAtomicAdd(double *address, double val) {
  return atomicAdd(address, val);
}

inline __device__ float gpuAtomicAdd(float *address, float val) {
  return atomicAdd(address, val);
}

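// The complex overload performs two independent atomic adds, one on the real
// part and one on the imaginary part; the pair is not updated atomically as a
// whole.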
template<typename T>
inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
  gpuAtomicAdd(&address->real_, val.real_);
  gpuAtomicAdd(&address->imag_, val.imag_);
}

/* Note [gpuAtomicAdd vs atomicAdd]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Some extensions, such as torchvision, call atomicAdd() directly and need
 * support for data types that the CUDA libraries do not provide. Only for
 * these types do we continue to provide atomicAdd overloads.
 */
inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
  return gpuAtomicAdd(address, val);
}

inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
  return gpuAtomicAdd(address, val);
}

inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
  gpuAtomicAdd(address, val);
}

inline __device__ void atomicAdd(int8_t *address, int8_t val) {
  gpuAtomicAdd(address, val);
}

inline __device__ void atomicAdd(int16_t *address, int16_t val) {
  gpuAtomicAdd(address, val);
}

inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  gpuAtomicAdd(address, val);
}

inline __device__ void atomicAdd(bool *address, bool val) {
  gpuAtomicAdd(address, val);
}
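
// Illustrative sketch (kept under #if 0, so it is not compiled): a hypothetical
// extension-style kernel that relies on the atomicAdd overloads above. The
// kernel name and its arguments are assumptions, not part of the original header.
#if 0
__global__ void scatter_add_half(at::Half* out, const at::Half* grad,
                                 const int64_t* index, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    // Resolves to the atomicAdd(at::Half*, at::Half) overload defined above.
    atomicAdd(out + index[i], grad[i]);
  }
}
#endif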

/* Note [explicitly non-returning atomics]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
 * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
 * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
 * therefore we need a new API 'gpuAtomicAddNoReturn'.
 */
template<typename T>
inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }

/* Special case fp32 atomic. */
#if defined(USE_ROCM)
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
#if defined(__gfx908__)
  atomicAddNoRet(address, val);
#else
  (void)unsafeAtomicAdd(address, val);
#endif
}
#else
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
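
// Illustrative sketch (kept under #if 0, so it is not compiled): a hypothetical
// histogram kernel where the old value is never consumed, which is the case
// gpuAtomicAddNoReturn is meant for. The kernel name and arguments are assumptions.
#if 0
__global__ void histogram_kernel(int32_t* bins, const int32_t* labels, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    // The return value is discarded, so the non-returning form may lower to a
    // faster instruction on some AMD GPUs (see the note above).
    gpuAtomicAddNoReturn(&bins[labels[i]], 1);
  }
}
#endif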

// Atomic multiplication implementation.

ATOMIC_INTEGER_IMPL(Mul)
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return bsum * val;
                                });
}

inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum * val;
                                    });
}

inline __device__ double gpuAtomicMul(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val * __longlong_as_double(assumed));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(val *
                                   __int_as_float(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}

// Atomic maximum implementation.

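// safe_max propagates NaN: if either operand is NaN the result is NaN, whereas
// a plain std::max would drop a NaN appearing in its second argument.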
template <typename T>
__host__ __device__ T safe_max(T a, T b) {
  #if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
    T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
  #else
    T max = at::_isnan(b) ? b : std::max<T>(a, b);
  #endif

  return max;
}

ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)

inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_max(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_max(bsum, val);
                                    });
}

inline __device__ double gpuAtomicMax(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMax(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(safe_max(val, __int_as_float(assumed))));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}
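
// Illustrative sketch (kept under #if 0, so it is not compiled): a hypothetical
// kernel that uses gpuAtomicMax to maintain a global running maximum. The
// kernel name and arguments are assumptions.
#if 0
__global__ void global_max_kernel(float* result, const float* data, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    // NaN inputs propagate into *result because gpuAtomicMax uses safe_max.
    gpuAtomicMax(result, data[i]);
  }
}
#endif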

// Atomic minimum implementation.

template <typename T>
__host__ __device__ T safe_min(T a, T b) {
  #if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
    T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
  #else
    T min = at::_isnan(b) ? b : std::min<T>(a, b);
  #endif

  return min;
}

ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)

inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_min(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_min(bsum, val);
                                    });
}

inline __device__ double gpuAtomicMin(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
                              });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMin(float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
  unsigned int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __float_as_int(safe_min(val, __int_as_float(assumed))));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __int_as_float(old);
}