#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <vector>

namespace at::native {

namespace {

static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;

// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch arguments
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
    72,
    60};
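// (Back-of-the-envelope check, assuming 8-byte pointers: with the caps above,
// TensorListMetadata<1> occupies roughly 110 * 8 + 110 * 8 + 320 * 1 +
// 320 * 4 + 4 bytes, about 3.3KB, leaving headroom for the functor and any
// extra kernel arguments.)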

// True when `p` is aligned to a kILP-element boundary, i.e. the pointer can be
// read and written with the vectorized load_store below.
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
  return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
}

// Copies kILP contiguous elements in one shot by reinterpreting `src` and
// `dst` as aligned_vector<T, kILP>; both pointers must satisfy is_aligned.
template <typename T>
__device__ __forceinline__ void load_store(
    T* dst,
    T* src,
    int64_t dst_offset,
    int64_t src_offset) {
  using LT = at::native::memory::aligned_vector<T, kILP>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
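//
// Typical use in a functor's vectorized fast path (illustrative sketch; `r_x`
// and `x` stand for a hypothetical per-thread register buffer and a global
// pointer, neither of which is defined here):
//
//   if (is_aligned(x)) {
//     // one vector transaction moves this thread's kILP elements
//     load_store(r_x, x, 0, threadIdx.x);
//   } else {
//     for (int64_t i = 0; i < kILP; i++) {
//       r_x[i] = x[threadIdx.x * kILP + i];
//     }
//   }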

// Kernel-argument payload for multi_tensor_apply: per-list, per-tensor data
// pointers and per-tensor sizes, plus a per-block mapping from CUDA block
// index to its assigned (tensor, chunk) pair.
template <int n>
struct TensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
  int start_tensor_this_launch;
};

template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
  const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
};

// note(mkozuki): with `c10::complex<double>` scalars, the generic caps for `n`
// of 1 and 2 would exceed the 4KB limit on CUDA kernel argument size, hence
// the smaller caps in the specializations below.
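// (e.g. at `n` == 1 the generic cap of 96 tensors would need roughly
// 96 * (8 + 8 + 16) + 320 + 320 * 4 = 4672 bytes of metadata alone, assuming
// 8-byte pointers; lowering the cap to 72 brings that down to 3904 bytes.)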
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
  const void* addresses[1]
                       [depth_to_max_tensors_scalarlist_of_complex_double[0]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
  int block_to_chunk[depth_to_max_blocks[1 - 1]];
};

template <>
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
  const void* addresses[2]
                       [depth_to_max_tensors_scalarlist_of_complex_double[1]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
  int block_to_chunk[depth_to_max_blocks[2 - 1]];
};

// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`,
// each element of which is a 1-element `at::Tensor` tracking the number of
// `step`s taken so far.
template <int n>
struct FusedOptimizerTensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
  const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
  int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void multi_tensor_apply_kernel(
    T tensorListMeta,
    U callable,
    ArgTypes... args) {
  // Hand the chunk information to the user-supplied functor to process however
  // it likes.
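  // (The callable typically maps blockIdx.x through
  // tensorListMeta.block_to_tensor / block_to_chunk to find its assigned
  // (tensor, chunk) pair; see the sketch preceding multi_tensor_apply below.)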
  callable(kChunkSize, tensorListMeta, args...);
}

} // namespace

// multi_tensor_apply enables horizontal fusion across lists of tensors.
// For example, whereas you once had a for-loop computing a + b = c, where a,
// b, and c are individual tensors in lists as, bs, and cs, you can now
// compute as + bs = cs with fewer kernel launches.
//
// You can also imagine bs to be a scalar list vs a tensor list.
//
// The function below takes in tensor lists, scalars, and a callable and
// chunks up the computation to launch as few kernels as possible by iterating
// through every "chunk" in every tensor (thus the nested for loops). In the
// simplest case, everything gets bundled into just one kernel launch, but
// due to the per-launch caps on the metadata (depth_to_max_tensors and
// depth_to_max_blocks), we may need to launch multiple kernels.
// Each kernel launch is defined by one tensorListMeta construct, which we
// use to track and reset the necessary metadata for each launch.
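//
// A rough, hypothetical sketch of a callable and a call site for the
// scalar-list overload below (ExampleScaleFunctor, `tensors`, and `scales`
// are illustrative names only):
//
//   struct ExampleScaleFunctor {
//     using opmath_t = float;
//     __device__ void operator()(
//         int64_t chunk_size,
//         TensorListScalarListMetadata<opmath_t, 1>& tl) {
//       const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
//       const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
//       const auto scale = tl.scalar_vals[tensor_loc];
//       auto* x =
//           (float*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
//       auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;
//       n = n > chunk_size ? chunk_size : n;
//       for (auto i = threadIdx.x; i < n; i += blockDim.x) {
//         x[i] *= scale;
//       }
//     }
//   };
//
//   std::vector<std::vector<at::Tensor>> lists{tensors.vec()};
//   multi_tensor_apply<1, float>(lists, scales, ExampleScaleFunctor());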
template <int depth, typename scalar_T, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::ArrayRef<Scalar> scalars,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  using scalar_vals_t = typename T::opmath_t;
  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // now we enter [chunking territory].
    // we will launch a kernel when EITHER the blocks get filled up OR
    // the tensors get filled up. There will always be at least one block
    // per tensor since the zero-sized ones will not enter the loop, so
    // the nested for-loop below iterates through the chunks of a single
    // tensor.
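    // (e.g. a tensor of 100,000 elements with kChunkSize == 65536 yields
    // chunks == 2: one full chunk and one of 34,464 elements.)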
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      // a tensor is not considered full unless all its chunks have been
      // processed
      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        // all chunks have already been handled in the kernel
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else { // blocks were full and tensor chunks remain
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.scalar_vals[0] =
              tensorListMeta.scalar_vals[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }

  // note: [finishing what we started]
  // if there's remaining work that hasn't yet filled up the tensors/blocks
  // but we've reached the end, submit the kernel to do the work!
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  TensorListMetadata<depth> tensorListMeta;
  tensorListMeta.start_tensor_this_launch = 0;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // see note: [chunking territory].
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
          tensorListMeta.start_tensor_this_launch = t + 1;
        } else {
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
          tensorListMeta.start_tensor_this_launch = t;
        }
      }
    }
  }

  // see note: [finishing what we started]
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  const auto num_tensors = tensor_lists[0].size();
  FusedOptimizerTensorListMetadata<depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (const auto& tensor_index : c10::irange(num_tensors)) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][tensor_index].numel() == 0) {
      continue;
    }
    tensorListMeta.state_steps_addresses[loc_tensor_info] =
        state_steps[tensor_index].const_data_ptr();
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][tensor_index].numel();
    for (const auto& d : c10::irange(depth)) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][tensor_index].const_data_ptr();
    }
    loc_tensor_info++;

    // see above note: [chunking territory]
    const auto numel = tensor_lists[0][tensor_index].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    TORCH_CHECK(chunks > -1);
    for (const auto& chunk : c10::irange(chunks)) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const auto tensor_full =
          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
           chunk == chunks - 1);
      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];

      if (tensor_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else {
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.state_steps_addresses[0] =
              tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
          for (const auto& d : c10::irange(depth)) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }

  // see above note: [finishing what we started]
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

} // namespace at::native