// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/DeviceSqrt.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

// Math headers are included at file scope, before any namespace is opened,
// so that their declarations are not nested inside at::native.
#if defined(USE_ROCM)
#include <math.h>
#else
#include <cmath>
#endif

namespace at { namespace native {
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions

// Device-side square root. In this branch the primary template is only
// declared; the float and double specializations below provide the
// definitions given in this header.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// float specialization: forwards to the global-namespace ::sqrtf.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// double specialization: forwards to the global-namespace ::sqrt.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
// Device-side square root that forwards to std::sqrt(val).
// The declared return type is double for every scalar_t, so whatever
// std::sqrt(val) yields is converted to double before being returned.
// NOTE(review): the USE_ROCM branch returns scalar_t instead. The mismatch is
// kept as-is because callers may rely on it -- confirm before unifying.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif
}}  // namespace at::native