1 #pragma once 2 3 #if !defined(USE_ROCM) 4 #include <cuda.h> // for CUDA_VERSION 5 #endif 6 7 #if !defined(USE_ROCM) 8 #include <cub/version.cuh> 9 #else 10 #define CUB_VERSION 0 11 #endif 12 13 // cub sort support for __nv_bfloat16 is added to cub 1.13 in: 14 // https://github.com/NVIDIA/cub/pull/306 15 #if CUB_VERSION >= 101300 16 #define CUB_SUPPORTS_NV_BFLOAT16() true 17 #else 18 #define CUB_SUPPORTS_NV_BFLOAT16() false 19 #endif 20 21 // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: 22 // https://github.com/NVIDIA/cub/pull/326 23 // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake 24 // starting from CUDA 11.5 25 #if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) 26 #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true 27 #else 28 #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false 29 #endif 30 31 // cub support for UniqueByKey is added to cub 1.16 in: 32 // https://github.com/NVIDIA/cub/pull/405 33 #if CUB_VERSION >= 101600 34 #define CUB_SUPPORTS_UNIQUE_BY_KEY() true 35 #else 36 #define CUB_SUPPORTS_UNIQUE_BY_KEY() false 37 #endif 38 39 // cub support for scan by key is added to cub 1.15 40 // in https://github.com/NVIDIA/cub/pull/376 41 #if CUB_VERSION >= 101500 42 #define CUB_SUPPORTS_SCAN_BY_KEY() 1 43 #else 44 #define CUB_SUPPORTS_SCAN_BY_KEY() 0 45 #endif 46 47 // cub support for cub::FutureValue is added to cub 1.15 in: 48 // https://github.com/NVIDIA/cub/pull/305 49 #if CUB_VERSION >= 101500 50 #define CUB_SUPPORTS_FUTURE_VALUE() true 51 #else 52 #define CUB_SUPPORTS_FUTURE_VALUE() false 53 #endif 54