1*da0073e9SAndroid Build Coastguard Worker #include "CUDATest.hpp" 2*da0073e9SAndroid Build Coastguard Worker #include <ATen/cuda/Exceptions.h> 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Worker namespace c10d { 5*da0073e9SAndroid Build Coastguard Worker namespace test { 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker namespace { waitClocks(const uint64_t count)8*da0073e9SAndroid Build Coastguard Worker__global__ void waitClocks(const uint64_t count) { 9*da0073e9SAndroid Build Coastguard Worker // Few AMD specific GPUs have different clock intrinsic 10*da0073e9SAndroid Build Coastguard Worker #if defined(__GFX11__) && defined(USE_ROCM) && !defined(__CUDA_ARCH__) 11*da0073e9SAndroid Build Coastguard Worker clock_t start = wall_clock64(); 12*da0073e9SAndroid Build Coastguard Worker #else 13*da0073e9SAndroid Build Coastguard Worker clock_t start = clock64(); 14*da0073e9SAndroid Build Coastguard Worker #endif 15*da0073e9SAndroid Build Coastguard Worker clock_t offset = 0; 16*da0073e9SAndroid Build Coastguard Worker while (offset < count) { 17*da0073e9SAndroid Build Coastguard Worker offset = clock() - start; 18*da0073e9SAndroid Build Coastguard Worker } 19*da0073e9SAndroid Build Coastguard Worker } 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker } // namespace 22*da0073e9SAndroid Build Coastguard Worker cudaSleep(at::cuda::CUDAStream & stream,uint64_t clocks)23*da0073e9SAndroid Build Coastguard Workervoid cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) { 24*da0073e9SAndroid Build Coastguard Worker waitClocks<<<1, 1, 0, stream.stream()>>>(clocks); 25*da0073e9SAndroid Build Coastguard Worker C10_CUDA_KERNEL_LAUNCH_CHECK(); 26*da0073e9SAndroid Build Coastguard Worker } 27*da0073e9SAndroid Build Coastguard Worker cudaNumDevices()28*da0073e9SAndroid Build Coastguard Workerint cudaNumDevices() { 29*da0073e9SAndroid Build Coastguard Worker int n = 0; 30*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(cudaGetDeviceCount(&n)); 31*da0073e9SAndroid Build Coastguard Worker return n; 32*da0073e9SAndroid Build Coastguard Worker } 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker } // namespace test 35*da0073e9SAndroid Build Coastguard Worker } // namespace c10d 36