#pragma once

#include <complex>
#include <type_traits>
#include <c10/core/ScalarType.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>


// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.

// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
// to the function that is being called.
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.

namespace at::native {

// `needs_dynamic_casting` compares the types expected by the iterator
// (i.e. the dtypes of the operands) with the actual argument and return
// types of func_t
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::template arg<nargs - 1>::type;
    using cpp_map = c10::CppTypeToScalarType<cpp_type>;

    // Compare the C++ type of argument (nargs - 1) against the dtype of the
    // corresponding input operand, then recurse over the remaining arguments.
    if (iter.input_dtype(nargs - 1) != cpp_map::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

// Base case (nargs == 0): all inputs matched, so check the return type of
// func_t against the dtype of the first output.
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::result_type;

    // we could assert output numbers are correct here, but checks
    // (including arity) are currently pushed outside of this struct.
    if constexpr (std::is_void_v<cpp_type>) {
      return false;
    } else {
      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
    }
  }
};

} // namespace at::native
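
// Example (a minimal sketch, not part of this header's API): a hypothetical
// kernel launcher could consult `needs_dynamic_casting` to choose between a
// direct dispatch and a casting dispatch. `launch_add_one` and `add_one` are
// illustrative names; only `needs_dynamic_casting`, `TensorIteratorBase`,
// and the call pattern come from this file and its users (CUDALoops.cuh,
// Loops.h). The lambda takes and returns float, so the check walks its one
// argument and its result type, returning true whenever the corresponding
// operand dtypes are not kFloat.
//
//   void launch_add_one(at::TensorIteratorBase& iter) {
//     auto add_one = [] (float x) -> float { return x + 1.0f; };
//     using func_t = decltype(add_one);
//     if (at::native::needs_dynamic_casting<func_t>::check(iter)) {
//       // Some operand dtype differs from float: dispatch a kernel that
//       // casts on each load/store (as the CUDA loops do), or hit the
//       // internal assert on CPU.
//     } else {
//       // All operand dtypes are kFloat: dispatch the no-cast fast path.
//     }
//   }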