#pragma once

#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <ATen/cuda/CUDAConfig.h>

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>

#include <ATen/native/cuda/MemoryAccess.cuh>

#include <ATen/native/cuda/CUDAJitLoops.cuh>

namespace at {
namespace native {

/* Note [Jiterator]
The "jiterator" just-in-time compiles the same kernels that
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
build size, and initial CUDA context size.

By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
This behavior is controlled with two environment variables:
- USE_PYTORCH_KERNEL_CACHE, if set to zero, disables all cache use
- PYTORCH_KERNEL_CACHE_PATH, if set, specifies the folder to use for cached kernels
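
For illustration (the values below are arbitrary examples), in a shell:

  USE_PYTORCH_KERNEL_CACHE=0                       # disable the on-disk cache
  PYTORCH_KERNEL_CACHE_PATH=/tmp/my_torch_kernels  # redirect the cache folder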

The jiterator currently has some limitations, however. It cannot:
- handle math on complex datatypes
- handle kernels with scalar parameters

These improvements will likely come soon.

For examples of how to use the jiterator see the i1 and gcd kernel
implementations, which pass jittable strings implementing their
operations instead of the typical CUDA functors.
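
As a rough usage sketch (the names below are illustrative, not the actual
i1/gcd code), a jitted elementwise binary kernel is typically wired up like:

  constexpr char my_op_name[] = "my_op";
  const auto my_op_string = jiterator_stringify(
    template <typename T>
    T my_op(T a, T b) {
      return a > b ? a : b;
    }
  );

  void my_op_kernel_cuda(TensorIteratorBase& iter) {
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_op_cuda", [&]() {
      jitted_gpu_kernel<my_op_name,
                        scalar_t,  // return_type
                        scalar_t,  // f_inputs_type
                        2          // arity
                        >(iter, my_op_string);
    });
  }

The function defined in the jittable string must have the same name as the
`name` template argument.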

To pass a runtime argument (similar to lambda captures in non-JIT kernels),
we need to pass additional arguments to `jitted_gpu_kernel` by value.
Currently only primitive C++ types used for computation are valid.
The order of these extra arguments should be the same as the order in which
they appear in the kernel's function signature (see polygamma for an example).

NOTE: One big restriction is that these arguments must come after the
arguments provided by TensorIterator. E.g., while capturing `n`, where
`scalar_t x` and `scalar_t y` are provided by TensorIterator,
* foo(scalar_t x, scalar_t y, int n) works!
* foo(int n, scalar_t x, scalar_t y) doesn't work
* foo(scalar_t x, int n, scalar_t y) doesn't work
*/

// Entrypoint for jitted GPU kernels.
// Only handles elementwise unary and binary kernels with a
// common dtype and a single output.
// NOTE: this assumes the op's iterator has a common_dtype.
// NOTE: We use std::tuple instead of a parameter pack
// for `extra_args` due to the following
// bug on older versions of clang:
// https://bugs.llvm.org/show_bug.cgi?id=23029
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // TODO: much of this preamble is common to both jitted_gpu_kernel and gpu_kernel.
  // Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }

    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's dtype differs from the common dtype
  // or if the result dtype differs from the output's dtype
  // Note: this is intentionally divergent from calling needs_dynamic_casting,
  // which is more general and inspects a lambda to determine if dynamic
  // casting is needed.
  bool needs_dynamic_casting = false;

  // Checks output
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const auto dtype0 = iter.dtype(0);
  if (dtype0 != return_scalar_type) {
    needs_dynamic_casting = true;
  }

  // Checks input(s)
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
    const auto dtypei = iter.dtype(i);
    if (dtypei != inputs_scalar_type) {
      needs_dynamic_casting = true;
      break;
    }
  }
  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: With `scalar_pos=NoScalar`, `scalar_val` is not used
    // for computation in the generated code and hence we pass a dummy
    // value of `0`.
    jitted_gpu_kernel_impl<
        /*name=*/name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        /*name=*/name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  } else {
    jitted_gpu_kernel_impl<
        /*name=*/name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  }
}

// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  // Currently the jiterator only handles binary functions where both inputs are of the same type (f_inputs_type)
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    auto scalar_val = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels, but structured kernels do it right and
    // we can assume the device is already set correctly
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
  } else if (iter.is_cpu_scalar(2)) {
    auto scalar_val = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
  } else {
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
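
// A minimal usage sketch, assuming a hypothetical binary op (the `my_binary_op_*`
// names below are illustrative, not part of ATen). A dispatching kernel passes
// its jittable string to this helper, which folds any cpu-scalar operand into
// the jitted kernel:
//
//   constexpr char my_binary_op_name[] = "my_binary_op";
//   const auto my_binary_op_string = jiterator_stringify(
//     template <typename T>
//     T my_binary_op(T a, T b) {
//       return a + b;
//     }
//   );
//
//   void my_binary_op_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_binary_op_cuda", [&]() {
//       opmath_jitted_gpu_kernel_with_scalars<my_binary_op_name, scalar_t, scalar_t>(
//           iter, my_binary_op_string);
//     });
//   }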

}} // at::native

#endif // AT_USE_JITERATOR()