1 #pragma once 2 3 #include <ATen/code_template.h> 4 5 namespace torch { 6 namespace jit { 7 namespace fuser { 8 namespace cpu { 9 10 /*with type_as not checking type of its input, a fusion group can have non-fp32 11 tensor as input. Correct code for this case is generated, however, nvrtc does 12 not know how to handle int*_t integer types, so typedefs help it handle those 13 cases*/ 14 15 static auto type_declarations_template = at::jit::CodeTemplate(R"( 16 17 #define POS_INFINITY INFINITY 18 #define NEG_INFINITY -INFINITY 19 20 typedef ${IndexType} IndexType; 21 template<typename T, size_t N> 22 struct TensorInfo { 23 T* data; 24 IndexType sizes[N]; 25 IndexType strides[N]; 26 }; 27 template<typename T> 28 struct TensorInfo<T, 0> { 29 T * data; 30 }; 31 )"); 32 33 static auto cpu_compilation_unit_template = at::jit::CodeTemplate(R"( 34 #include <math.h> 35 #include <cstddef> 36 #include <cstdint> 37 38 double rsqrt(double x) { 39 return 1.0/sqrt(x); 40 } 41 42 float rsqrtf(float x) { 43 return 1.0f/sqrtf(x); 44 } 45 46 double frac(double x) { 47 return x - trunc(x); 48 } 49 50 float fracf(float x) { 51 return x - truncf(x); 52 } 53 54 ${type_declarations} 55 56 #ifdef _MSC_VER 57 template<size_t n> struct int_of_size; 58 59 #define DEFINE_INT_OF_SIZE(int_t) \ 60 template<> struct int_of_size<sizeof(int_t)> { using type = int_t; } 61 62 DEFINE_INT_OF_SIZE(int64_t); 63 DEFINE_INT_OF_SIZE(int32_t); 64 DEFINE_INT_OF_SIZE(int16_t); 65 DEFINE_INT_OF_SIZE(int8_t); 66 67 #undef DEFINE_INT_OF_SIZE 68 69 template <typename T> 70 using int_same_size_t = typename int_of_size<sizeof(T)>::type; 71 72 #define IndexTypeLoop int_same_size_t<IndexType> 73 #define ToIndexTypeLoop(x) static_cast<IndexTypeLoop>(x) 74 #else 75 #define IndexTypeLoop IndexType 76 #define ToIndexTypeLoop(x) x 77 #endif 78 79 #define OMP_THRESHOLD 100000 80 static void ${kernelName}_kernel(IndexType totalElements, ${formals}) { 81 #pragma omp parallel for if(totalElements > OMP_THRESHOLD) 82 for (IndexTypeLoop linearIndex = 0; 83 linearIndex < ToIndexTypeLoop(totalElements); 84 linearIndex += 1) { 85 // Convert `linearIndex` into an offset of tensor: 86 ${tensorOffsets} 87 // calculate the results 88 ${kernelBody} 89 } 90 } 91 92 #ifdef _WIN32 93 #define JIT_API __declspec(dllexport) 94 #else 95 #define JIT_API 96 #endif 97 98 extern "C" 99 JIT_API void ${kernelName}(IndexType totalElements, void ** args) { 100 ${kernelName}_kernel(totalElements ${,argument_loads}); 101 } 102 )"); 103 104 } // namespace cpu 105 } // namespace fuser 106 } // namespace jit 107 } // namespace torch 108