xref: /aosp_15_r20/external/pytorch/test/cpp/c10d/CUDATest.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include "CUDATest.hpp"
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/cuda/Exceptions.h>
3*da0073e9SAndroid Build Coastguard Worker 
namespace c10d {
namespace test {

namespace {

// Reads the device counter used by waitClocks(). The start timestamp and
// every later poll must come from this same counter. Mixing counters (e.g.
// a clock64() start with a 32-bit clock() poll) makes the elapsed difference
// meaningless: once the 64-bit start exceeds 2^32 the difference is negative,
// converts to a huge unsigned value, and the wait ends immediately.
__device__ __forceinline__ long long readWaitClock() {
  // Few AMD specific GPUs have different clock intrinsic
#if defined(__GFX11__) && defined(USE_ROCM) && !defined(__CUDA_ARCH__)
  return wall_clock64();
#else
  return clock64();
#endif
}

// Busy-waits the executing thread until at least `count` ticks of the
// readWaitClock() counter have elapsed. cudaSleep() launches it with a
// single thread so it occupies the stream it is enqueued on.
__global__ void waitClocks(const uint64_t count) {
  const long long start = readWaitClock();
  uint64_t elapsed = 0;
  while (elapsed < count) {
    // Difference taken in 64-bit from the same counter as `start`, then
    // compared as unsigned to match the type of `count`.
    elapsed = static_cast<uint64_t>(readWaitClock() - start);
  }
}

} // namespace

// Enqueues a single-thread kernel on `stream` that spins for at least
// `clocks` ticks of the device counter, delaying work queued after it on
// that stream. Launch errors are checked with C10_CUDA_KERNEL_LAUNCH_CHECK().
void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) {
  waitClocks<<<1, 1, 0, stream.stream()>>>(clocks);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Returns the device count reported by cudaGetDeviceCount(). A failing query
// goes through C10_CUDA_CHECK_WARN (presumably a warning rather than an
// exception, per its name — confirm). `n` starts at 0; NOTE(review): whether
// the runtime overwrites it on failure is not established here.
int cudaNumDevices() {
  int n = 0;
  C10_CUDA_CHECK_WARN(cudaGetDeviceCount(&n));
  return n;
}

} // namespace test
} // namespace c10d
36