xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDAUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/CUDAContext.h>
4 
5 namespace at::cuda {
6 
7 // Check if every tensor in a list of tensors matches the current
8 // device.
check_device(ArrayRef<Tensor> ts)9 inline bool check_device(ArrayRef<Tensor> ts) {
10   if (ts.empty()) {
11     return true;
12   }
13   Device curDevice = Device(kCUDA, current_device());
14   for (const Tensor& t : ts) {
15     if (t.device() != curDevice) return false;
16   }
17   return true;
18 }
19 
20 } // namespace at::cuda
21