xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/cub_definitions.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #if !defined(USE_ROCM)
4 #include <cuda.h>  // for CUDA_VERSION
5 #endif
6 
7 #if !defined(USE_ROCM)
8 #include <cub/version.cuh>
9 #else
10 #define CUB_VERSION 0
11 #endif
12 
// Sorting of __nv_bfloat16 keys arrived in cub 1.13
// (https://github.com/NVIDIA/cub/pull/306); older releases lack it.
// Expands to the literal `true` or `false`.
#if CUB_VERSION < 101300
#define CUB_SUPPORTS_NV_BFLOAT16() false
#else
#define CUB_SUPPORTS_NV_BFLOAT16() true
#endif
20 
// CUB_WRAPPED_NAMESPACE support was added in cub 1.13.1
// (https://github.com/NVIDIA/cub/pull/326). Starting from CUDA 11.5 the
// macro is defined globally in cmake/Dependencies.cmake.
// Expands to `true` when either CUB_WRAPPED_NAMESPACE or
// THRUST_CUB_WRAPPED_NAMESPACE is defined, and to `false` when neither is.
#if !defined(CUB_WRAPPED_NAMESPACE) && !defined(THRUST_CUB_WRAPPED_NAMESPACE)
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#else
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
#endif
30 
// UniqueByKey arrived in cub 1.16
// (https://github.com/NVIDIA/cub/pull/405); older releases lack it.
// Expands to the literal `true` or `false`.
#if CUB_VERSION < 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#endif
38 
// cub support for scan by key is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/376
// Expands to `true`/`false` like the other CUB_SUPPORTS_* macros in this
// file. In C++ the preprocessor evaluates `true` as 1 and `false` as 0, so
// `#if CUB_SUPPORTS_SCAN_BY_KEY()` keeps working as before.
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() true
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() false
#endif
46 
// cub::FutureValue arrived in cub 1.15
// (https://github.com/NVIDIA/cub/pull/305); older releases lack it.
// Expands to the literal `true` or `false`.
#if CUB_VERSION < 101500
#define CUB_SUPPORTS_FUTURE_VALUE() false
#else
#define CUB_SUPPORTS_FUTURE_VALUE() true
#endif
54