1 #pragma once
2 #include <ATen/jit_macros.h>
3
4 // Jiterator functions are guarded behind this macro
5 #if AT_USE_JITERATOR()
6
7 #include <ATen/OpMathType.h>
8 #include <ATen/TensorIterator.h>
9 #include <ATen/core/Array.h>
10 #include <ATen/cuda/CUDAContext.h>
11 #include <ATen/cuda/detail/OffsetCalculator.cuh>
12 #include <ATen/native/cuda/jit_utils.h>
13 #include <ATen/native/cuda/MemoryAccess.cuh>
14 #include <ATen/native/cuda/thread_constants.h>
15
16 #include <ATen/native/cuda/Loops.cuh>
17
18 #include <c10/macros/Macros.h>
19 #include <c10/core/ScalarType.h>
20 #include <c10/util/SmallBuffer.h>
21
22 #include <initializer_list>
23 #include <type_traits>
24 #include <tuple>
25 #include <mutex>
26
27 namespace at {
28 namespace native {
29
// Builds a std::array<void*, sizeof...(I)> holding the address of each
// element std::get<I>(t), in index order. The index_sequence argument is only
// a tag used to deduce the pack I...; its value is never read.
template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> /*seq*/) {
  // When I... is empty, `t` is otherwise unused; this silences that warning.
  static_cast<void>(t);
  using pointer_array_t = std::array<void*, sizeof...(I)>;
  return pointer_array_t{static_cast<void*>(&std::get<I>(t))...};
}
36
// Turns a tuple into a std::array<void*, N> (N = number of tuple elements)
// whose entries point at the tuple's elements in order. This is the form the
// jitted kernel launch expects for extra arguments.
// NOTE: the tuple is taken by reference and the array stores addresses of its
// elements, so those pointers are only valid while the tuple is alive.
template <typename... Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  return tuple_to_array_helper(extra_args, std::index_sequence_for<Args...>{});
}
47
// Kernel slots for the contiguous, non-casting launch path, one per
// vectorization width (1, 2 or 4 elements). The width is chosen at launch time
// by at::cuda::jit::can_vectorize_up_to(), and each width is compiled as its
// own kernel.
struct JittedVecKernelCache {
  at::cuda::jit::NvrtcFunction vec1;  // width 1: generated as a non-vectorized kernel
  at::cuda::jit::NvrtcFunction vec2;  // width 2
  at::cuda::jit::NvrtcFunction vec4;  // width 4
};
54
// Every compiled variant of one jitted kernel, for one device. Each member
// corresponds to one of the launch strategies that jitted_gpu_kernel_generic
// picks between.
struct JittedKernelVariantCache {
  JittedVecKernelCache vec;                            // contiguous, no dynamic casting
  at::cuda::jit::NvrtcFunction noncontiguous;          // non-contiguous, no dynamic casting
  at::cuda::jit::NvrtcFunction dynamic_contiguous;     // contiguous, dynamic casting
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;  // non-contiguous, dynamic casting
};
61
// Concatenates the fixed kernel-argument pointers (`args`) and the
// caller-supplied extra-argument pointers (`extra_args`) into a single
// contiguous buffer, in that order. The result is suitable to pass as the
// `void**` argument list of a jitted kernel launch.
inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  const auto fixed_count = args.size();
  c10::SmallBuffer<void*, 64> packed(fixed_count + extra_args.size());
  // std::copy returns one-past-the-last written slot, which is where the
  // extra arguments begin.
  void** tail = std::copy(args.begin(), args.end(), packed.data());
  std::copy(extra_args.begin(), extra_args.end(), tail);
  return packed;
}
70
// Launches the non-vectorized ("unrolled") jitted kernel over N elements.
//
// The kernel cached in `fn_cache` is compiled on first use while holding
// `jiterator_mutex`. The generated code depends on `contiguous`, on whether
// dynamic casting is needed (inferred from the loader/storer types) and on
// `scalar_pos`.
//
// Kernel arguments are passed by address, in this order: N, data, ic, oc, l,
// s, scalar_val, followed by `extra_args`.
//
// Preconditions: 0 < N <= INT32_MAX (asserted).
template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

  // Casting code is generated unless BOTH the loader and the storer are the
  // cast-free variants.
  constexpr bool needs_casting =
      !(std::is_same<loader_t, memory::LoadWithoutCast>::value &&
        std::is_same<storer_t, memory::StoreWithoutCast>::value);

  // Check, lock, re-check: only one thread compiles a missing kernel.
  // NOTE(review): the first, unlocked read of fn_cache.function is not
  // synchronized with the locked write unless that member is atomic (not
  // visible here) -- confirm this is acceptable.
  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> guard{jiterator_mutex};
    if (!fn_cache.function) {
      auto source = at::cuda::jit::generate_code(
          desc, contiguous, needs_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(source, desc.name);
    }
  }

  // N <= INT32_MAX (asserted above), so the int64 ceil-division cannot
  // overflow and its result fits in 32 bits.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  auto kernel_args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(
      fn_cache, kernel_args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}
110
// Launches the contiguous, non-dynamic-casting jitted kernel over N elements.
// It uses the widest vectorization (4, 2 or 1 elements) that
// at::cuda::jit::can_vectorize_up_to() reports for `data`.
//
// Each width has its own slot in `fn_cache`. A missing kernel is compiled
// lazily while holding `jiterator_mutex`. Width 1 is generated as a plain
// (non-vectorized) contiguous kernel. It is launched with trivial offset
// calculators and the cast-free loader/storer, matching the argument layout
// that launch_jitted_unrolled_kernel uses.
//
// Preconditions: 0 < N <= INT32_MAX (asserted).
template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but the result always fits in 32 bits.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Select the cache slot for this width. The pointer starts as nullptr so it
  // has a defined value on every path. Otherwise its validity would depend on
  // TORCH_INTERNAL_ASSERT(false) never returning, and compilers can warn that
  // it may be used uninitialized.
  at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
  switch (vec_size) {
    case 4:
      fn_ptr = &fn_cache.vec4;
      break;
    case 2:
      fn_ptr = &fn_cache.vec2;
      break;
    case 1:
      fn_ptr = &fn_cache.vec1;
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
  }

  const bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!
      // This path is always contiguous and never dynamically casts.
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      // Each vectorized width is compiled under a distinct kernel name.
      std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    // The vectorized kernel takes no offset calculators or loader/storer.
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    // NVCC complains about unused variables l and s.
    // It should be false positive in most cases, so we suppress the warnings.
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
#pragma nv_diagnostic pop
  }
}
174
// Chooses and launches one of four jitted-kernel variants for `iter`.
// The iterator must have `arity` inputs and one output, and must be
// addressable with 32-bit indexing; all three are asserted.
//
// Variants (mirroring the non-jitted cases in CUDALoops.cuh gpu_kernel_impl):
//   - Case 1: no dynamic casting, contiguous     -> vectorized kernel (cache.vec)
//   - Case 2: no dynamic casting, noncontiguous  -> unrolled kernel (cache.noncontiguous)
//   - Case 3: dynamic casting, contiguous        -> unrolled kernel (cache.dynamic_contiguous)
//   - Case 4: dynamic casting, noncontiguous     -> unrolled kernel (cache.dynamic_noncontiguous)
//
// An iterator with zero elements returns immediately without compiling or
// launching anything.
template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  const int64_t numel = iter.numel();
  // Every launch helper asserts N > 0. Without this check an empty iterator
  // would trip an internal assert instead of being a no-op.
  if (numel == 0) {
    return;
  }

  // Tensor 0 is the output; tensors 1..arity are the inputs.
  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  const bool contiguous = iter.is_contiguous();

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 both need a casting storer for the single output (tensor 0)
  // and casting loaders for the `arity` inputs.
  auto storer = memory::StoreWithCast<1>(iter);
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}
253
// Entry point for launching the jiterator kernel `name` (source `f`) over `iter`.
//
// Each template instantiation owns three function-local statics, created on
// first call:
//   - a mutex serializing kernel compilation;
//   - one JittedKernelVariantCache per CUDA device, indexed by the device
//     index of `iter`;
//   - the kernel descriptor. Note it is built from the `f` passed on the FIRST
//     call of this instantiation; later calls reuse it.
// `extra_args` is passed by value and stays alive for the whole call, so the
// element pointers taken from it remain valid through the launch.
//
// NOTE: static to reduce chances of name collision.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  // the same compute capability
  static std::mutex jiterator_mutex;
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
  // TODO: Support more than 1 output
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
      result_type, f_inputs_type, ExtraArgs...>(
          name, f, /*nInputs=*/arity, /*nOutputs=*/1);

  JittedKernelVariantCache& per_device_cache = device_caches[iter.device().index()];
  auto extra_arg_ptrs = tuple_to_array(extra_args);

  jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      per_device_cache,
      desc,
      scalar_pos,
      extra_arg_ptrs,
      iter,
      dynamic_casting,
      &scalar_val);
}
293
294 }} // at::native
295
296 #endif // AT_USE_JITERATOR()
297