#pragma once

// Square-root helper for device code, dispatching to the platform's math
// library. System headers are included here at file scope: including them
// inside `namespace at::native` can place the header's declarations inside
// that namespace on first inclusion, breaking the `::sqrtf` / `::sqrt` lookups
// below.
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions
#include <math.h>
#else
#include <cmath>
#endif

namespace at { namespace native {
#if defined(USE_ROCM)
// Primary template is declared only; float and double are the sole
// specializations defined below, so using device_sqrt with any other
// scalar_t compiles but fails at link time.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// Single-precision square root via the C math library's sqrtf.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// Double-precision square root via the C math library's sqrt.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
// Generic square root using the std::sqrt overload selected for scalar_t.
//
// The result is always returned as double, whatever scalar_t is; for float
// input, std::sqrt(float) computes in single precision and the value is then
// widened to double on return.
// NOTE(review): the ROCm branch above returns scalar_t instead, so callers see
// a different return type per platform. Kept as double here to avoid changing
// numeric results for existing callers — confirm which is intended.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif
}} // namespace at::native