1 #pragma once 2 3 #include <ATen/cuda/CUDAContext.h> 4 5 namespace at::cuda { 6 7 // Check if every tensor in a list of tensors matches the current 8 // device. check_device(ArrayRef<Tensor> ts)9inline bool check_device(ArrayRef<Tensor> ts) { 10 if (ts.empty()) { 11 return true; 12 } 13 Device curDevice = Device(kCUDA, current_device()); 14 for (const Tensor& t : ts) { 15 if (t.device() != curDevice) return false; 16 } 17 return true; 18 } 19 20 } // namespace at::cuda 21