#pragma once
#include <ATen/jit_macros.h>

// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>

#include <ATen/native/cuda/Loops.cuh>

#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>

#include <algorithm>
#include <array>
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <tuple>
#include <mutex>

namespace at {
namespace native {

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
    constexpr auto size = seq.size();
    (void)t; // suppresses "unused parameter" warning when the tuple is empty
    return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

// Helper function to convert a tuple to std::array<void*, N>
// for passing the arguments to a CUDA kernel.
// NOTE: We capture the tuple by reference,
// so the pointers in the returned array are only valid
// as long as the tuple is alive.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
    constexpr auto tuple_size = sizeof...(Args);
    return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}
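
// Illustrative example (hypothetical values): given
//   std::tuple<int, float> extra{1, 2.0f};
//   auto arr = tuple_to_array(extra);
// arr is std::array<void*, 2>{&std::get<0>(extra), &std::get<1>(extra)};
// the pointers in arr dangle once `extra` goes out of scope.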

struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};

struct JittedKernelVariantCache {
  JittedVecKernelCache vec;
  at::cuda::jit::NvrtcFunction noncontiguous;
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};
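
// Rough mapping from cache entries to the launch cases in
// jitted_gpu_kernel_generic below:
//   vec                   -> Case 1: no dynamic casting, contiguous (vectorized path)
//   noncontiguous         -> Case 2: no dynamic casting, noncontiguous
//   dynamic_contiguous    -> Case 3: dynamic casting, contiguous
//   dynamic_noncontiguous -> Case 4: dynamic casting, noncontiguous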

inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}
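
// Illustrative layout of the packed argument buffer for a jitted launch
// (see launch_jitted_unrolled_kernel below): for
//   pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args)
// the buffer holds the 7 fixed kernel parameters first, followed by the
// pointers to any extra arguments, in order.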

template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Casting the result to int is always safe: the intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

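  // Compiles the kernel lazily on first use. The double-checked lock below
  // ensures only one thread pays the NVRTC compilation cost; subsequent calls
  // reuse the cached NvrtcFunction.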
  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
  {num_threads(), 1u, 1u});
}

template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast the result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  //   fn_ptr is set to the appropriate function based on the vec size and GPU used
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size == 1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitted vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!

      // Generates the program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;

      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
// NVCC complains about unused variables l and s.
// These should be false positives in most cases, so we suppress the warnings.
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
#pragma nv_diagnostic pop
  }
}

template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

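  // Gathers the data pointers: the output is the zeroth tensor in the
  // TensorIterator, followed by the arity input tensors.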
  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 are handled below
  // Both require constructing a storer (there is exactly one output, as asserted above) and one or more loaders

  // Creates the store cast to the output (the zeroth tensor in the TensorIterator)
  auto storer = memory::StoreWithCast<1>(iter);

  // Creates the load casts from the inputs (note the offset indexing into the iterator's tensors 1...n)
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}

// NOTE: static to reduce chances of name collision.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {

  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  //   the same compute capability
  static std::mutex jiterator_mutex;
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
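  // The statics above give one kernel-variant cache per CUDA device (selected
  // by device index below); compilation of cache entries is serialized through
  // jiterator_mutex.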

  constexpr int nInputs = arity;
  constexpr int nOutputs = 1;  // TODO: Support more than 1 output
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
    result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);

  auto &cache = device_caches[iter.device().index()];
  auto extra_args_array = tuple_to_array(extra_args);
  return jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      cache,
      desc,
      scalar_pos,
      extra_args_array,
      iter,
      dynamic_casting,
      &scalar_val
    );
}
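
// Illustrative (hypothetical) invocation, only to show how the template and
// runtime arguments line up; real callers construct these arguments via the
// jiterator entry points elsewhere in ATen:
//
//   constexpr char my_op_name[] = "my_op";            // hypothetical kernel name
//   jitted_gpu_kernel_impl<my_op_name, /*result_type=*/float,
//                          /*f_inputs_type=*/float, /*arity=*/1>(
//       iter,                       // TensorIterator with 1 output and 1 input
//       my_op_source,               // hypothetical C++ source string for the op
//       /*dynamic_casting=*/false,
//       /*scalar_val=*/0.0f,
//       std::make_tuple());         // no extra kernel arguments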

}}  // at::native

#endif // AT_USE_JITERATOR()