xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/cpu/resource_strings.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/code_template.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace fuser {
8 namespace cpu {
9 
10 /*Because type_as does not check the type of its input, a fusion group can
11 have a non-fp32 tensor as input. Correct code is generated for this case;
12 however, nvrtc does not know how to handle the int*_t integer types, so
13 typedefs help it handle those cases.*/
14 
// Code template holding the type declarations shared by generated kernels.
//
// The text inside the raw string is emitted source, not code of this header:
//   * POS_INFINITY / NEG_INFINITY are defined in terms of INFINITY.
//   * `${IndexType}` is a placeholder; the rendered text typedefs it to
//     `IndexType`.
//   * TensorInfo<T, N> carries a data pointer plus N sizes and N strides, all
//     of type IndexType; the N == 0 specialization carries only the data
//     pointer.
//
// NOTE(review): presumably the rendered result is what gets substituted for
// `${type_declarations}` in the compilation-unit template -- confirm in the
// code that formats these templates.
15 static auto type_declarations_template = at::jit::CodeTemplate(R"(
16 
17 #define POS_INFINITY INFINITY
18 #define NEG_INFINITY -INFINITY
19 
20 typedef ${IndexType} IndexType;
21 template<typename T, size_t N>
22 struct TensorInfo {
23   T* data;
24   IndexType sizes[N];
25   IndexType strides[N];
26 };
27 template<typename T>
28 struct TensorInfo<T, 0> {
29   T * data;
30 };
31 )");
32 
// Code template for one complete generated CPU translation unit.
//
// The text inside the raw string is emitted source. In order, it contains:
//   * includes of <math.h>, <cstddef> and <cstdint>;
//   * helper functions rsqrt/rsqrtf (1 / sqrt(x)) and frac/fracf
//     (x - trunc(x));
//   * the `${type_declarations}` placeholder;
//   * the loop index type: under _MSC_VER, IndexTypeLoop is the intN_t type
//     (int64_t/int32_t/int16_t/int8_t) whose size equals sizeof(IndexType) and
//     ToIndexTypeLoop casts to it; otherwise IndexTypeLoop is IndexType itself
//     and ToIndexTypeLoop is the identity.
//     NOTE(review): presumably MSVC's OpenMP needs a signed loop index --
//     confirm;
//   * `${kernelName}_kernel`, a static function that iterates linearIndex over
//     [0, totalElements), expanding `${tensorOffsets}` and then
//     `${kernelBody}` on each iteration. The loop carries an OpenMP
//     `parallel for` pragma that applies only if totalElements exceeds
//     OMP_THRESHOLD (100000);
//   * `${kernelName}`, an extern "C" entry point taking
//     (IndexType totalElements, void** args) that forwards to the kernel. It
//     is marked JIT_API, which is __declspec(dllexport) when _WIN32 is defined
//     and empty otherwise.
//
// Placeholders: type_declarations, kernelName, formals, tensorOffsets,
// kernelBody, argument_loads.
// NOTE(review): `${,argument_loads}` is presumably CodeTemplate's
// comma-prefixed list form (a leading comma only when the list is non-empty)
// -- confirm against ATen/code_template.h.
33 static auto cpu_compilation_unit_template = at::jit::CodeTemplate(R"(
34 #include <math.h>
35 #include <cstddef>
36 #include <cstdint>
37 
38 double rsqrt(double x) {
39   return 1.0/sqrt(x);
40 }
41 
42 float rsqrtf(float x) {
43   return 1.0f/sqrtf(x);
44 }
45 
46 double frac(double x) {
47   return x - trunc(x);
48 }
49 
50 float fracf(float x) {
51   return x - truncf(x);
52 }
53 
54 ${type_declarations}
55 
56 #ifdef _MSC_VER
57 template<size_t n> struct int_of_size;
58 
59 #define DEFINE_INT_OF_SIZE(int_t) \
60 template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }
61 
62 DEFINE_INT_OF_SIZE(int64_t);
63 DEFINE_INT_OF_SIZE(int32_t);
64 DEFINE_INT_OF_SIZE(int16_t);
65 DEFINE_INT_OF_SIZE(int8_t);
66 
67 #undef DEFINE_INT_OF_SIZE
68 
69 template <typename T>
70 using int_same_size_t = typename int_of_size<sizeof(T)>::type;
71 
72 #define IndexTypeLoop int_same_size_t<IndexType>
73 #define ToIndexTypeLoop(x) static_cast<IndexTypeLoop>(x)
74 #else
75 #define IndexTypeLoop IndexType
76 #define ToIndexTypeLoop(x) x
77 #endif
78 
79 #define OMP_THRESHOLD 100000
80 static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
81   #pragma omp parallel for if(totalElements > OMP_THRESHOLD)
82   for (IndexTypeLoop linearIndex = 0;
83         linearIndex < ToIndexTypeLoop(totalElements);
84         linearIndex += 1) {
85       // Convert `linearIndex` into an offset of tensor:
86       ${tensorOffsets}
87       // calculate the results
88       ${kernelBody}
89     }
90 }
91 
92 #ifdef _WIN32
93 #define JIT_API __declspec(dllexport)
94 #else
95 #define JIT_API
96 #endif
97 
98 extern "C"
99 JIT_API void ${kernelName}(IndexType totalElements, void ** args) {
100   ${kernelName}_kernel(totalElements ${,argument_loads});
101 }
102 )");
103 
104 } // namespace cpu
105 } // namespace fuser
106 } // namespace jit
107 } // namespace torch
108