#include <cuda_fp16.h>
#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#ifndef USE_ROCM
#ifndef _WIN32
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif
#endif

#include <ATen/NestedTensorImpl.h>

#define BLOCK_DIM 256
#define GRID_DIM_Y 16

namespace at {
namespace native {

template <typename T>
__global__ void remove_padding_transform0213_2(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int offset = offsets[batch_id];
  const int* sizes_i = output_sizes + batch_id * output_dim;
  const int numel_i = sizes_i[0] * sizes_i[1];
  int input_offset =
      batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
  for (int ii = 0; ii < (numel_i / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int i2 = i / sizes_i[1];
    const int i13 = i % sizes_i[1];
    const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
    const int i3 = i13 % (sizes_i[1] / input_sizes[1]);

    output[offset + i] = input
        [input_offset + i1 * input_sizes[2] * input_sizes[3] +
         i2 * input_sizes[3] + i3];
  }
  const int i = (numel_i / grainsize) * grainsize + tid;
  if (i < numel_i) {
    const int i2 = i / sizes_i[1];
    const int i13 = i % sizes_i[1];
    const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
    const int i3 = i13 % (sizes_i[1] / input_sizes[1]);
    output[offset + i] = input
        [input_offset + i1 * input_sizes[2] * input_sizes[3] +
         i2 * input_sizes[3] + i3];
  }
}
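
// Index-mapping sketch (illustrative; n1..n3 are shorthand for
// input_sizes[1..3] and are not variables in this file): the kernel above
// reads a padded input viewed as [batch, n1, n2, n3] and writes a packed
// output row of shape [n2, n1 * n3], i.e. the "0213" permutation with the
// last two source dims folded together. For a flat output index i:
//   i2  = i / (n1 * n3)   // position along the n2 dim
//   i13 = i % (n1 * n3)   // folded (n1, n3) index
//   i1  = i13 / n3, i3 = i13 % n3
// which matches the arithmetic in the loop body, assuming
// sizes_i[1] == input_sizes[1] * input_sizes[3].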

template <typename T>
__global__ void remove_padding_2(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int offset = offsets[batch_id];
  const int* sizes_i = output_sizes + batch_id * output_dim;
  const int numel_i = sizes_i[0] * sizes_i[1];
  int input_offset = batch_id * input_sizes[1] * input_sizes[2];
  for (int ii = 0; ii < (numel_i / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int i0 = i / sizes_i[1];
    const int i1 = i % sizes_i[1];
    const int i0_offset = i0 * input_sizes[2];
    output[offset + i] = input[input_offset + i0_offset + i1];
  }
  const int i = (numel_i / grainsize) * grainsize + tid;
  if (i < numel_i) {
    const int i0 = i / sizes_i[1];
    const int i1 = i % sizes_i[1];
    const int i0_offset = i0 * input_sizes[2];
    output[offset + i] = input[input_offset + i0_offset + i1];
  }
}

template <typename T>
__global__ void remove_padding(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int offset = offsets[batch_id];
  const int* sizes_i = output_sizes + batch_id * output_dim;
  const int numel_i = sizes_i[0] * sizes_i[1] * sizes_i[2];
  int input_offset =
      batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
  for (int ii = 0; ii < (numel_i / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int i0 = i / (sizes_i[1] * sizes_i[2]);
    const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
    const int i2 = i % sizes_i[2];
    const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
    const int i1_offset = i1 * input_sizes[3];
    output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
  }
  const int i = (numel_i / grainsize) * grainsize + tid;
  if (i < numel_i) {
    const int i0 = i / (sizes_i[1] * sizes_i[2]);
    const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
    const int i2 = i % sizes_i[2];
    const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
    const int i1_offset = i1 * input_sizes[3];
    output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
  }
}

template <typename T>
void remove_padding_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size) {
  dim3 grid;
  grid.x = batch_size;
  grid.y = GRID_DIM_Y;
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  if (output_dim == 2) {
    remove_padding_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
        input,
        output,
        offsets,
        input_sizes,
        output_sizes,
        output_dim,
        batch_size);
  } else {
    remove_padding<T><<<grid, BLOCK_DIM, 0, stream>>>(
        input,
        output,
        offsets,
        input_sizes,
        output_sizes,
        output_dim,
        batch_size);
  }
}
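
// Launch-shape note (illustrative): each batch entry gets its own blockIdx.x,
// and GRID_DIM_Y * BLOCK_DIM threads cooperate on that entry in a grid-stride
// fashion, so the per-entry "grainsize" is GRID_DIM_Y * BLOCK_DIM
// (16 * 256 = 4096 elements per sweep). offsets[b] is expected to hold the
// cumulative number of unpadded elements preceding batch entry b in the packed
// output buffer.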

template <typename T>
void remove_padding_transform0213_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size) {
  dim3 grid;
  grid.x = batch_size;
  grid.y = GRID_DIM_Y;
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(
      output_dim == 2,
      "remove padding transform0213 only support output dim == 2");

  remove_padding_transform0213_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
      input,
      output,
      offsets,
      input_sizes,
      output_sizes,
      output_dim,
      batch_size);
}

template void remove_padding_kernelLauncher<float>(
    const float* input,
    float* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

template void remove_padding_kernelLauncher<c10::Half>(
    const c10::Half* input,
    c10::Half* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

template void remove_padding_transform0213_kernelLauncher<float>(
    const float* input,
    float* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

template void remove_padding_transform0213_kernelLauncher<c10::Half>(
    const c10::Half* input,
    c10::Half* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

template <typename T>
__global__ void add_padding_1(
    const T* input,
    T* output,
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    int output_sizes_1,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int* sizes_i = input_sizes + batch_id * input_dim;
  const int batch_output_offset = batch_id * output_sizes_1;
  for (int ii = 0; ii < (output_sizes_1 / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int output_offset = batch_output_offset + i;
    if (batch_id < batch_size && i < sizes_i[0]) {
      const int batch_input_offset = offsets[batch_id];
      output[output_offset] = input[batch_input_offset + i];
    } else {
      output[output_offset] = padding_value;
    }
  }
  const int i = (output_sizes_1 / grainsize) * grainsize + tid;
  if (i < output_sizes_1) {
    const int output_offset = batch_output_offset + i;
    if (batch_id < batch_size && (i < sizes_i[0])) {
      const int batch_input_offset = offsets[batch_id];
      output[output_offset] = input[batch_input_offset + i];
    } else {
      output[output_offset] = padding_value;
    }
  }
}

template <typename T>
__global__ void add_padding_2(
    const T* input,
    T* output,
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    int output_sizes_1,
    int output_sizes_2,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int* sizes_i = input_sizes + batch_id * input_dim;
  const int output_offset = batch_id * output_sizes_1 * output_sizes_2;
  const int output_numel = output_sizes_1 * output_sizes_2;
  for (int ii = 0; ii < (output_numel / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int i0 = i / (output_sizes_2);
    const int i1 = i - i0 * output_sizes_2;
    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1]) {
      const int offset = offsets[batch_id];
      const int input_offset = offset + i0 * sizes_i[1] + i1;
      output[output_offset + i] = input[input_offset];
    } else {
      output[output_offset + i] = padding_value;
    }
  }
  const int i = (output_numel / grainsize) * grainsize + tid;
  if (i < output_numel) {
    const int i0 = i / (output_sizes_2);
    const int i1 = i - i0 * output_sizes_2;
    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1]) {
      const int offset = offsets[batch_id];
      const int input_offset = offset + i0 * sizes_i[1] + i1;
      output[output_offset + i] = input[input_offset];
    } else {
      output[output_offset + i] = padding_value;
    }
  }
}

template <typename T>
__global__ void add_padding_3(
    const T* input,
    T* output,
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    int output_sizes_1,
    int output_sizes_2,
    int output_sizes_3,
    const int batch_size) {
  const int batch_id = blockIdx.x;
  const int grid_id = blockIdx.y;
  const int tid = threadIdx.x + grid_id * BLOCK_DIM;
  const int grainsize = GRID_DIM_Y * BLOCK_DIM;
  const int* sizes_i = input_sizes + batch_id * input_dim;
  const int output_offset =
      batch_id * output_sizes_1 * output_sizes_2 * output_sizes_3;
  const int output_numel = output_sizes_1 * output_sizes_2 * output_sizes_3;
  for (int ii = 0; ii < (output_numel / grainsize); ii++) {
    const int i = ii * grainsize + tid;
    const int i0 = i / (output_sizes_2 * output_sizes_3);
    const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
    const int i2 = i % output_sizes_3;
    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
        i2 < sizes_i[2]) {
      const int offset = offsets[batch_id];
      const int input_offset =
          offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
      output[output_offset + i] = input[input_offset];
    } else {
      output[output_offset + i] = padding_value;
    }
  }
  const int i = (output_numel / grainsize) * grainsize + tid;
  if (i < output_numel) {
    const int i0 = i / (output_sizes_2 * output_sizes_3);
    const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
    const int i2 = i % output_sizes_3;
    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
        i2 < sizes_i[2]) {
      const int offset = offsets[batch_id];
      const int input_offset =
          offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
      output[output_offset + i] = input[input_offset];
    } else {
      output[output_offset + i] = padding_value;
    }
  }
}

template <typename T>
void add_padding_kernelLauncher(
    T* input, // [batch_size x None]
    T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size) {
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  dim3 grid;
  grid.x = output_batch_size;
  grid.y = GRID_DIM_Y;
  if (input_dim == 1) {
    add_padding_1<T><<<grid, BLOCK_DIM, 0, stream>>>(
        input,
        output,
        padding_value,
        offsets,
        input_sizes,
        input_dim,
        output_sizes[1],
        batch_size);
  }
  if (input_dim == 2) {
    add_padding_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
        input,
        output,
        padding_value,
        offsets,
        input_sizes,
        input_dim,
        output_sizes[1],
        output_sizes[2],
        batch_size);
  }
  if (input_dim == 3) {
    add_padding_3<T><<<grid, BLOCK_DIM, 0, stream>>>(
        input,
        output,
        padding_value,
        offsets,
        input_sizes,
        input_dim,
        output_sizes[1],
        output_sizes[2],
        output_sizes[3],
        batch_size);
  }
}
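
// Worked example (illustrative): with input_dim == 2, two nested rows of
// shapes [2, 3] and [1, 3], offsets = {0, 6}, and output_sizes = {2, 2, 3},
// add_padding_2 produces a dense [2, 2, 3] buffer in which output[1][1][:]
// is filled with padding_value, because the second row has only
// sizes_i[0] == 1 valid positions along dim 1.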

template void add_padding_kernelLauncher<double>(
    double* input,
    double* output,
    double padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size);

template void add_padding_kernelLauncher<float>(
    float* input,
    float* output,
    float padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size);

template void add_padding_kernelLauncher<c10::Half>(
    c10::Half* input,
    c10::Half* output,
    c10::Half padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size);

// NB: The following code covers jagged <-> padded dense conversions and was
// lifted from fbgemm_gpu. For more details, see
// https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/src/jagged_tensor_ops

// Passing the lambda expression argument by value instead of by reference to
// avoid an "internal compiler error: in maybe_undo_parenthesized_ref" error
// with specific compiler versions.
#define JAGGED_TENSOR_DISPATCH_DIMS()                                         \
  AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \
    switch (num_jagged_dim) {                                                 \
      case 1:                                                                 \
        INVOKE_KERNEL_WITH_DIM(1);                                            \
        break;                                                                \
      case 2:                                                                 \
        INVOKE_KERNEL_WITH_DIM(2);                                            \
        break;                                                                \
      case 3:                                                                 \
        INVOKE_KERNEL_WITH_DIM(3);                                            \
        break;                                                                \
      case 4:                                                                 \
        INVOKE_KERNEL_WITH_DIM(4);                                            \
        break;                                                                \
      case 5:                                                                 \
        INVOKE_KERNEL_WITH_DIM(5);                                            \
        break;                                                                \
      default:                                                                \
        TORCH_CHECK(                                                          \
            false, "unsupported number of jagged dim ", num_jagged_dim);      \
    }                                                                         \
  });

inline std::string torch_tensor_device_name(const at::Tensor& ten) {
  return c10::DeviceTypeName(ten.device().type());
}

inline std::string torch_tensor_device_name(
    const std::optional<at::Tensor>& ten) {
  if (ten.has_value()) {
    return torch_tensor_device_name(ten.value());
  } else {
    return "N/A";
  }
}

inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) {
  return ten.is_cuda();
}

inline bool torch_tensor_on_cuda_gpu_check(
    const std::optional<at::Tensor>& ten) {
  return !ten.has_value() || torch_tensor_on_cuda_gpu_check(ten.value());
}

#define TENSOR_ON_CUDA_GPU(x)                                  \
  TORCH_CHECK(                                                 \
      torch_tensor_on_cuda_gpu_check(x),                       \
      #x " must be a CUDA tensor; it is currently on device ", \
      torch_tensor_device_name(x))

// A wrapper class for passing dynamically sized dimension information (e.g.
// tensor.dims()) from the host to device.
constexpr size_t kStackArrayMaxDims = 5;

template <typename T>
struct StackArray {
  T vals[kStackArrayMaxDims];
  size_t ndim;
};
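
// Usage sketch (illustrative): on the host, a StackArray is filled with up to
// kStackArrayMaxDims entries and passed to a kernel by value, e.g.
//   StackArray<int64_t> dims;
//   dims.ndim = 2;
//   dims.vals[0] = jagged_size_0;
//   dims.vals[1] = jagged_size_1;
// so device code can index dims.vals[d] without a separate device-side copy.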

// Warp size
#ifdef USE_ROCM
static constexpr int32_t kWarpSize = 64;
#else
static constexpr int32_t kWarpSize = 32;
#endif
// Max thread num in one thread block
static constexpr int32_t kMaxThreads = 1024;

#define DEVICE_INLINE __device__ C10_ALWAYS_INLINE

__host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

__host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) {
  return a / b * b;
}
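
// Quick arithmetic check (illustrative): div_round_up(10, 4) == 3 and
// round_down(10, 4) == 8, i.e. ceil- and floor-style rounding to a multiple
// of b; both are used below to size grids and shared-memory budgets.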

inline std::tuple<dim3, dim3, StackArray<int64_t>> check_shape_and_partition_(
    const Tensor& values,
    const std::vector<Tensor>& offsets,
    const Tensor& dense_tensor) {
  const int outer_dense_size = dense_tensor.size(0);
  TORCH_CHECK(
      outer_dense_size == offsets[0].numel() - 1,
      "outer_dense_size, ",
      outer_dense_size,
      " != offsets[0].numel() - 1, ",
      offsets[0].numel() - 1);
  const int inner_dense_size = dense_tensor.size(-1);
  TORCH_CHECK(
      inner_dense_size == values.size(-1),
      "inner_dense_size, ",
      inner_dense_size,
      " != values.size(-1), ",
      values.size(-1));
  const int jagged_folded_size =
      dense_tensor.numel() / (outer_dense_size * inner_dense_size);

  const int threads_x =
      inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size;
  const int threads_y = kMaxThreads / kWarpSize;
  const dim3 blocks(
      div_round_up(outer_dense_size * jagged_folded_size, threads_y));

  StackArray<int64_t> jagged_dims_tensor;
  const int num_jagged_dim = dense_tensor.dim() - 2;
  TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims);
  jagged_dims_tensor.ndim = num_jagged_dim;
  std::memcpy(
      &(jagged_dims_tensor.vals[0]),
      dense_tensor.sizes().data() + 1,
      num_jagged_dim * sizeof(int64_t));
  return {dim3(threads_x, threads_y), blocks, jagged_dims_tensor};
}

template <int NUM_JAGGED_DIM, typename index_t>
DEVICE_INLINE bool walk_down_tensor_storage_tree_(
    int& offset,
    const int flattened_jagged_idx,
    const StackArray<int64_t>& jagged_dims,
    const StackArray<index_t*>& x_offsets) {
  // compute coordinates
  int jagged_coords[NUM_JAGGED_DIM];
  int j_temp = flattened_jagged_idx;
#pragma unroll
  for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
    const int jagged_size = jagged_dims.vals[d];
    jagged_coords[d] = j_temp % jagged_size;
    j_temp /= jagged_size;
  }

  // walk down the tree
  bool is_zero = false;
#pragma unroll
  for (int d = 0; d < NUM_JAGGED_DIM; ++d) {
    const int begin = x_offsets.vals[d][offset];
    const int end = x_offsets.vals[d][offset + 1];
    if (jagged_coords[d] >= end - begin) {
      is_zero = true;
      break;
    }
    offset = begin + jagged_coords[d];
  }
  return is_zero;
}
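
// Walk-through (illustrative): for one jagged dim d, x_offsets.vals[d] is a
// standard offsets array, so the children of the current node live in the
// half-open range [x_offsets.vals[d][offset], x_offsets.vals[d][offset + 1]).
// If the requested jagged coordinate falls outside that range, the position is
// in the padding region and the caller substitutes padding_value; otherwise
// offset advances to the matching child and the walk continues to the next dim.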

// output = f(x, y) where x is jagged, y is dense, and output is dense.
// A generic elementwise operation between a jagged tensor and a dense tensor.
// This kernel assumes jagged dims are clustered together, preceded by outer
// dense dimensions and followed by inner dense dimensions.
// The outer/inner dense dimensions, and the jagged dimensions in between, are
// assumed to be folded, so physically the dense tensor is 3D and the values of
// the jagged tensor are 2D.
// To support an arbitrary number of jagged dimensions, we pass a vector of
// pointers to offset tensors (this is ugly; probably we could use a nested
// tensor here).
// This kernel parallelizes the (folded) inner dense dimension across
// blockDim.x, so the inner dense dimension should be similar to or bigger than
// the warp size.
// We rely on the compiler unrolling the compile-time constant NUM_JAGGED_DIM.
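//
// Shape sketch (illustrative): with outer dense size B, folded jagged size J,
// and inner dense size D, the dense operand y is viewed as [B, J, D] and the
// jagged values x_values as [nnz, D]; each (blockIdx.x, threadIdx.y) pair
// strides over the B * J outer iterations while blockDim.x strides over D two
// elements at a time (the 2 * iidx / 2 * iidx + 1 pattern below).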
template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
__global__
__launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_(
    const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
        x_values,
    StackArray<index_t*> x_offsets,
    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y,
    at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> output,
    StackArray<int64_t> jagged_dims,
    F f,
    const scalar_t padding_value) {
  const int outer_dense_size = y.size(0);
  const int jagged_folded_size = y.size(1);
  const int inner_dense_size = y.size(2);

  const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y;
  const int outer_stride = gridDim.x * blockDim.y;
  for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
       outer += outer_stride) {
    const int oidx = outer / jagged_folded_size;
    const int jidx = outer % jagged_folded_size;

    int offset = oidx;
    const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
        offset, jidx, jagged_dims, x_offsets);

    if (is_zero) {
      int iidx;
      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
           iidx += blockDim.x) {
        output[oidx][jidx][2 * iidx] =
            f(padding_value, y[oidx][jidx][2 * iidx]);
        output[oidx][jidx][2 * iidx + 1] =
            f(padding_value, y[oidx][jidx][2 * iidx + 1]);
      }
      if (iidx * 2 + 1 == inner_dense_size) {
        output[oidx][jidx][2 * iidx] =
            f(padding_value, y[oidx][jidx][2 * iidx]);
      }
    } else {
      int iidx;
      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
           iidx += blockDim.x) {
        output[oidx][jidx][2 * iidx] =
            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
        output[oidx][jidx][2 * iidx + 1] =
            f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]);
      }
      if (iidx * 2 + 1 == inner_dense_size) {
        output[oidx][jidx][2 * iidx] =
            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
      }
    }
  }
}

template <typename scalar_t, typename F>
void jagged_dense_elementwise_dense_output_(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y,
    const Tensor& output,
    F f,
    const scalar_t padding_value = static_cast<scalar_t>(0)) {
  TENSOR_ON_CUDA_GPU(x_values);
  for (auto& x_offset : x_offsets) {
    TENSOR_ON_CUDA_GPU(x_offset);
  }

  const int num_jagged_dim = y.dim() - 2;
  TORCH_CHECK(
      x_offsets.size() == static_cast<size_t>(num_jagged_dim),
      "x_offsets.size(), ",
      x_offsets.size(),
      " != num_jagged_dim ",
      num_jagged_dim);

  if (y.numel() == 0) {
    return;
  }

  dim3 threads, blocks;
  StackArray<int64_t> jagged_dims_tensor;
  std::tie(threads, blocks, jagged_dims_tensor) =
      check_shape_and_partition_(x_values, x_offsets, y);

  // Canonicalize y and output to 3D, collapsing jagged dimensions.
  const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
  Tensor output_reshaped = output.view(y_reshaped.sizes());

#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                \
  {                                                                           \
    std::vector<Tensor> x_offsets_contig;                                     \
    x_offsets_contig.resize(num_jagged_dim);                                  \
    StackArray<index_t*> x_offset_ptrs;                                       \
    x_offset_ptrs.ndim = num_jagged_dim;                                      \
    for (int d = 0; d < num_jagged_dim; ++d) {                                \
      x_offsets_contig[d] = x_offsets[d].contiguous();                        \
      x_offset_ptrs.vals[d] =                                                 \
          x_offsets_contig[d].template data_ptr<index_t>();                   \
    }                                                                         \
    jagged_dense_elementwise_dense_output_kernel_<NUM_JAGGED_DIM, index_t>    \
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(           \
            x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
            x_offset_ptrs,                                                    \
            y_reshaped                                                        \
                .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),     \
            output_reshaped                                                   \
                .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),     \
            jagged_dims_tensor,                                               \
            f,                                                                \
            padding_value);                                                   \
  }

  JAGGED_TENSOR_DISPATCH_DIMS();
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef INVOKE_KERNEL_WITH_DIM
}

#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                 \
  {                                                                            \
    auto [threads, blocks, jagged_dims_tensor] =                               \
        check_shape_and_partition_(x_values, x_offsets, y);                    \
    blocks.x = div_round_up(x_values.size(0), threads.y);                      \
    std::vector<Tensor> x_offsets_contig;                                      \
    x_offsets_contig.resize(num_jagged_dim);                                   \
    StackArray<index_t*> x_offset_ptrs;                                        \
    x_offset_ptrs.ndim = num_jagged_dim;                                       \
    StackArray<int64_t> x_offset_sizes;                                        \
    x_offset_sizes.ndim = num_jagged_dim;                                      \
    for (int d = 0; d < num_jagged_dim; ++d) {                                 \
      x_offsets_contig[d] = x_offsets[d].contiguous();                         \
      x_offset_ptrs.vals[d] =                                                  \
          x_offsets_contig[d].template data_ptr<index_t>();                    \
      x_offset_sizes.vals[d] = x_offsets[d].numel();                           \
    }                                                                          \
    jagged_dense_dense_elementwise_jagged_output_kernel_<                      \
        NUM_JAGGED_DIM,                                                        \
        index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
        x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),      \
        x_offset_ptrs,                                                         \
        x_offset_sizes,                                                        \
        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
        output_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
        jagged_dims_tensor,                                                    \
        [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/)            \
            -> scalar_t { return f(x, y); });                                  \
  }

template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
__global__
__launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_(
    const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
        x_values,
    StackArray<index_t*> x_offsets,
    StackArray<int64_t> x_offsets_sizes,
    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_0,
    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_1,
    at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
        output_values,
    StackArray<int64_t> jagged_dims,
    F f) {
  const int outer_dense_size = y_0.size(0);
  const int inner_dense_size = y_0.size(2);
  const int nnz = x_values.size(0);

  const int offset_begin = blockIdx.x * blockDim.y + threadIdx.y;
  const int offset_stride = gridDim.x * blockDim.y;
  for (int offset = offset_begin; offset < nnz; offset += offset_stride) {
    int offset_temp = offset;
    int jidx = 0;
    bool truncated = false;
    int dim_prod = 1;
#pragma unroll
    for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
      // Binary search for the first offset entry that is greater than
      // offset_temp.
      int count = x_offsets_sizes.vals[d] - 1;
      int first = 1;
      while (count > 0) {
        int idx = first;
        int step = count / 2;
        idx += step;
        if (x_offsets.vals[d][idx] <= offset_temp) {
          first = ++idx;
          count -= step + 1;
        } else {
          count = step;
        }
      }

      --first;
      int coord = offset_temp - x_offsets.vals[d][first];
      if (coord >= jagged_dims.vals[d]) {
        truncated = true;
        break;
      }
      jidx += coord * dim_prod;
      dim_prod *= jagged_dims.vals[d];
      offset_temp = first;
    }

    if (offset_temp >= outer_dense_size) {
      // This can happen when values have more elements than the last element
      // of offsets.
      truncated = true;
    }
    if (!truncated) {
      const int oidx = offset_temp;
      int iidx;
      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
           iidx += blockDim.x) {
        output_values[offset][2 * iidx] =
            f(x_values[offset][2 * iidx],
              y_0[oidx][jidx][2 * iidx],
              y_1[oidx][jidx][2 * iidx]);
        output_values[offset][2 * iidx + 1] =
            f(x_values[offset][2 * iidx + 1],
              y_0[oidx][jidx][2 * iidx + 1],
              y_1[oidx][jidx][2 * iidx + 1]);
      }
      if (iidx * 2 + 1 == inner_dense_size) {
        output_values[offset][2 * iidx] =
            f(x_values[offset][2 * iidx],
              y_0[oidx][jidx][2 * iidx],
              y_1[oidx][jidx][2 * iidx]);
      }
    } else {
      int iidx;
      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
           iidx += blockDim.x) {
        output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0);
        output_values[offset][2 * iidx + 1] =
            f(x_values[offset][2 * iidx + 1], 0, 0);
      }
      if (iidx * 2 + 1 == inner_dense_size) {
        output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0);
      }
    }
  }
}

///@addtogroup jagged-tensor-ops-cuda
template <typename scalar_t, typename F>
void jagged_dense_elementwise_jagged_output_(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y,
    const Tensor& output_values,
    F f) {
  TENSOR_ON_CUDA_GPU(x_values);
  for (auto& x_offset : x_offsets) {
    TENSOR_ON_CUDA_GPU(x_offset);
  }

  const int num_jagged_dim = y.dim() - 2;
  TORCH_CHECK(
      x_offsets.size() == static_cast<size_t>(num_jagged_dim),
      "x_offsets.size(), ",
      x_offsets.size(),
      " != num_jagged_dim, ",
      num_jagged_dim);

  if (y.numel() == 0 || x_values.numel() == 0) {
    return;
  }

  // Canonicalize y to 3D, collapsing jagged dimensions.
  const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});

  JAGGED_TENSOR_DISPATCH_DIMS();
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

#undef INVOKE_KERNEL_WITH_DIM

template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<int64_t> {
  __device__ int64_t* getPointer() {
    extern __shared__ int64_t s_int64_t[];
    return s_int64_t;
  }
};

template <>
struct SharedMemory<int32_t> {
  __device__ int32_t* getPointer() {
    extern __shared__ int32_t s_int32_t[];
    return s_int32_t;
  }
};
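
// Dynamic shared-memory sketch (illustrative): SharedMemory<T>::getPointer()
// exposes the block's dynamically sized extern __shared__ buffer, so a kernel
// launched as kernel<<<blocks, threads, smem_bytes, stream>>>(...) can do
//   index_t* buf = SharedMemory<index_t>().getPointer();
// and treat buf as an array of smem_bytes / sizeof(index_t) elements; the
// per-type specializations exist only to give each T a distinct symbol name.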

template <typename index_t>
__global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_(
    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
    int nnz,
    int B) {
  struct SharedMemory<index_t> smem;
  index_t* offsets_sh = smem.getPointer();

  for (int i = threadIdx.x; i < B + 1; i += blockDim.x) {
    offsets_sh[i] = offsets[i];
  }
  __syncthreads();
  int row = threadIdx.x + blockIdx.x * blockDim.x;
  if (row >= nnz)
    return;
  int first = -1;
  int count = B - 1;
  first = 1;
  while (count > 0) {
    int idx = first;
    int step = count / 2;
    idx += step;
    if (offsets_sh[idx] <= row) {
      first = ++idx;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  --first;

  int dense_row = first;
  int offset = offsets_sh[dense_row];
  int dense_col = row - offset;
  rows[row] = dense_row;
  cols[row] = dense_col;
}

struct VecType128 {
  typedef float4 TType; // Transaction Type
  typedef struct __align__(16) {
    __half a, b, c, d, w, x, y, z;
  }
  half8;

  union Data {
    half8 val;
    TType mask;
  } data;

  __device__ VecType128() {
    data.mask = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
  }
};

struct VecType64 {
  typedef float2 TType; // Transaction Type
  typedef struct __align__(8) {
    __half a, b, c, d;
  }
  half4;

  union Data {
    half4 val;
    TType mask;
  } data;

  __device__ VecType64() {
    data.mask = make_float2(0.0f, 0.0f);
  }
};

struct VecType32 {
  typedef float TType; // Transaction Type

  union Data {
    __half2 val;
    TType mask;
  } data;

  __device__ VecType32() {
    data.mask = 0.0f;
  }
};
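
// Vector-width note (illustrative): VecType128, VecType64, and VecType32 pack
// 8, 4, and 2 __half values into a single 16-, 8-, or 4-byte transaction
// respectively; the gather kernel below first consumes the embedding row in
// 128-bit chunks, then falls back to 64-bit, 32-bit, and finally scalar
// __half accesses for the remaining tail elements.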

template <typename F>
__device__ void f128(
    VecType128& v_out,
    const VecType128& x,
    const VecType128& y0,
    const VecType128& y1,
    F f) {
  v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a);
  v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b);
  v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c);
  v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d);
  v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w);
  v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x);
  v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y);
  v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z);
}

template <typename F>
__device__ void f64(
    VecType64& v_out,
    const VecType64& x,
    const VecType64& y0,
    const VecType64& y1,
    F f) {
  v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a);
  v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b);
  v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c);
  v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d);
}

template <typename F>
__device__ void f32(
    VecType32& v_out,
    const VecType32& x,
    const VecType32& y0,
    const VecType32& y1,
    F f) {
  v_out.data.val = __halves2half2(
      f(__low2half(x.data.val),
        __low2half(y0.data.val),
        __low2half(y1.data.val)),
      f(__high2half(x.data.val),
        __high2half(y0.data.val),
        __high2half(y1.data.val)));
}

template <typename index_t, typename F>
__global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_(
    at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits> values,
    const at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits>
        x_values,
    const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y0,
    const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y1,
    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
    const int nnz,
    const int E,
    F f) {
  int values_row = threadIdx.y + blockIdx.y * blockDim.y;
  if (values_row >= nnz)
    return;
  for (int real_row = values_row; real_row < nnz;
       real_row += blockDim.y * gridDim.y) {
    int dense_row = rows[real_row];
    int dense_col = cols[real_row];
    __half* values_ptr = reinterpret_cast<__half*>(&values[real_row][0]);
    const __half* x_ptr =
        reinterpret_cast<const __half*>(&x_values[real_row][0]);
    const __half* y0_ptr =
        reinterpret_cast<const __half*>(&y0[dense_row][dense_col][0]);
    const __half* y1_ptr =
        reinterpret_cast<const __half*>(&y1[dense_row][dense_col][0]);
    if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) &&
        (dense_col < y1.size(1)) && (dense_row < y1.size(0)) &&
        (dense_col >= 0) && (dense_row >= 0)) {
      for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) {
        VecType128 v_x, v_out, v_y0, v_y1;
        v_x.data.mask =
            (reinterpret_cast<const VecType128::TType*>(x_ptr))[tid];
        v_y0.data.mask =
            (reinterpret_cast<const VecType128::TType*>(y0_ptr))[tid];
        v_y1.data.mask =
            (reinterpret_cast<const VecType128::TType*>(y1_ptr))[tid];
        f128(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType128::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4;
           tid += blockDim.x) {
        VecType64 v_x, v_out, v_y0, v_y1;
        v_x.data.mask = (reinterpret_cast<const VecType64::TType*>(x_ptr))[tid];
        v_y0.data.mask =
            (reinterpret_cast<const VecType64::TType*>(y0_ptr))[tid];
        v_y1.data.mask =
            (reinterpret_cast<const VecType64::TType*>(y1_ptr))[tid];
        f64(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType64::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2;
           tid += blockDim.x) {
        VecType32 v_x, v_out, v_y0, v_y1;
        v_x.data.mask = (reinterpret_cast<const VecType32::TType*>(x_ptr))[tid];
        v_y0.data.mask =
            (reinterpret_cast<const VecType32::TType*>(y0_ptr))[tid];
        v_y1.data.mask =
            (reinterpret_cast<const VecType32::TType*>(y1_ptr))[tid];
        f32(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType32::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) {
        auto v_x = static_cast<__half>(x_ptr[tid]);
        auto v_y0 = static_cast<__half>(y0_ptr[tid]);
        auto v_y1 = static_cast<__half>(y1_ptr[tid]);
        values_ptr[tid] = f(v_x, v_y0, v_y1);
      }
    } else {
      for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) {
        VecType128 v_x, v_out, v_y0, v_y1;
        v_x.data.mask =
            (reinterpret_cast<const VecType128::TType*>(x_ptr))[tid];
        f128(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType128::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4;
           tid += blockDim.x) {
        VecType64 v_x, v_out, v_y0, v_y1;
        v_x.data.mask = (reinterpret_cast<const VecType64::TType*>(x_ptr))[tid];
        f64(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType64::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2;
           tid += blockDim.x) {
        VecType32 v_x, v_out, v_y0, v_y1;
        v_x.data.mask = (reinterpret_cast<const VecType32::TType*>(x_ptr))[tid];
        f32(v_out, v_x, v_y0, v_y1, f);
        (reinterpret_cast<VecType32::TType*>(values_ptr))[tid] =
            v_out.data.mask;
      }
      for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) {
        auto v_x = static_cast<__half>(x_ptr[tid]);
        values_ptr[tid] = f(v_x, __half{}, __half{});
      }
    }
  }
}

// Check to see if the inputs to the op are amenable to the fast path
inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
    const int& num_jagged_dim,
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y_0_reshaped,
    const Tensor& y_1_reshaped,
    const Tensor& output_values) {
  bool matches = true;
  matches &= (num_jagged_dim == 1);

  // Unit stride embedding dim
  matches &= (x_values.stride(-1) == 1);
  matches &= (output_values.stride(-1) == 1);
  matches &= (y_0_reshaped.stride(-1) == 1);
  matches &= (y_1_reshaped.stride(-1) == 1);

  // Each row is aligned to 128-bit
  matches &= (x_values.stride(-2) % 8 == 0);
  matches &= (output_values.stride(-2) % 8 == 0);
  matches &= (y_0_reshaped.stride(-2) % 8 == 0);
  matches &= (y_1_reshaped.stride(-2) % 8 == 0);

  // Base addresses aligned to 128-bit
  matches &= (reinterpret_cast<uint64_t>(x_values.data_ptr()) % 16 == 0);
  matches &= (reinterpret_cast<uint64_t>(output_values.data_ptr()) % 16 == 0);
  matches &= (reinterpret_cast<uint64_t>(y_0_reshaped.data_ptr()) % 16 == 0);
  matches &= (reinterpret_cast<uint64_t>(y_1_reshaped.data_ptr()) % 16 == 0);

  // Rows and cols fit into int32_t
  matches &= (y_0_reshaped.size(0) < INT_MAX);
  matches &= (y_0_reshaped.size(1) < INT_MAX);

  int max_shared_bytes;
#ifndef USE_ROCM
  C10_CUDA_CHECK(cudaDeviceGetAttribute(
      &max_shared_bytes,
      cudaDevAttrMaxSharedMemoryPerBlockOptin,
      y_0_reshaped.get_device()));
#else
  // MI100 has 64 KB local memory (shared memory) per workgroup
  max_shared_bytes = 64 << 10;
#endif
  int shared_kb = max_shared_bytes >> 10;
#ifndef USE_ROCM
  // Use 2/3 of the available GPU shared mem; leave room for L1$.
  int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
  TORCH_CHECK(used_shared_kb > 0);
#else
  // MI100 has independent shared mem and L1
  int used_shared_kb = shared_kb;
#endif
  int used_shared_bytes = used_shared_kb << 10;
  AT_DISPATCH_INDEX_TYPES(
      x_offsets[0].scalar_type(), "check_shared_memory", [&] {
        auto B = y_0_reshaped.size(0);
        // the default shared memory on V100/A100/H100 is 48 KB from
        // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
        if ((B + 1) * sizeof(index_t) >= used_shared_bytes) {
          matches = false;
        }
      });
  return matches;
}
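
// Fast-path summary (illustrative): the optimized half-precision path is taken
// only when there is exactly one jagged dim, every operand has unit stride in
// the embedding dim, rows and base pointers are 16-byte aligned (so the
// 128-bit VecType loads are legal), the dense row/col counts fit in int32_t,
// and the (B + 1) offsets fit within the usable shared-memory budget.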

#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                 \
  {                                                                            \
    auto [threads, blocks, jagged_dims_tensor] =                               \
        check_shape_and_partition_(x_values, x_offsets, y);                    \
    blocks.x = div_round_up(x_values.size(0), threads.y);                      \
    std::vector<Tensor> x_offsets_contig;                                      \
    x_offsets_contig.resize(num_jagged_dim);                                   \
    StackArray<index_t*> x_offset_ptrs;                                        \
    x_offset_ptrs.ndim = num_jagged_dim;                                       \
    StackArray<int64_t> x_offset_sizes;                                        \
    x_offset_sizes.ndim = num_jagged_dim;                                      \
    for (int d = 0; d < num_jagged_dim; ++d) {                                 \
      x_offsets_contig[d] = x_offsets[d].contiguous();                         \
      x_offset_ptrs.vals[d] =                                                  \
          x_offsets_contig[d].template data_ptr<index_t>();                    \
      x_offset_sizes.vals[d] = x_offsets[d].numel();                           \
    }                                                                          \
    jagged_dense_dense_elementwise_jagged_output_kernel_<                      \
        NUM_JAGGED_DIM,                                                        \
        index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
        x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),      \
        x_offset_ptrs,                                                         \
        x_offset_sizes,                                                        \
        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
        output_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
        jagged_dims_tensor,                                                    \
        [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/)            \
            -> scalar_t { return f(x, y); });                                  \
  }

inline int calc_used_shared_bytes(const int device) {
    int max_shared_bytes;
#ifndef USE_ROCM
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_shared_bytes,
        cudaDevAttrMaxSharedMemoryPerBlockOptin,
        device));
#else
    // MI100 has 64 KB local memory (shared memory) per workgroup
    max_shared_bytes = 64 << 10;
#endif
    int shared_kb = max_shared_bytes >> 10;
#ifndef USE_ROCM
    // Use 2/3 of the available GPU shared mem; leave room for L1$.
    int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
    TORCH_CHECK(used_shared_kb > 0);
#else
    // MI100 has independent shared mem and L1
    int used_shared_kb = shared_kb;
#endif
    int used_shared_bytes = used_shared_kb << 10;
    return used_shared_bytes;
}

template <typename index_t>
inline void set_max_dynamic_shared_mem_size_for_opt_search_kernel(const int used_shared_bytes) {
#ifndef USE_ROCM
    C10_CUDA_CHECK(cudaFuncSetAttribute(
        jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
            index_t>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB
#endif
}

///@addtogroup jagged-tensor-ops-cuda
template <typename scalar_t, typename F>
void jagged_dense_elementwise_jagged_output_opt_(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y,
    const Tensor& output_values,
    F f) {
  TENSOR_ON_CUDA_GPU(x_values);
  for (auto& x_offset : x_offsets) {
    TENSOR_ON_CUDA_GPU(x_offset);
  }

  const int num_jagged_dim = y.dim() - 2;
  TORCH_CHECK(
      x_offsets.size() == static_cast<size_t>(num_jagged_dim),
      "x_offsets.size(), ",
      x_offsets.size(),
      " != num_jagged_dim, ",
      num_jagged_dim);

  if (y.numel() == 0 || x_values.numel() == 0) {
    return;
  }

  // Canonicalize y to 3D, collapsing jagged dimensions.
  const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
  if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
          num_jagged_dim,
          x_values,
          x_offsets,
          y_reshaped,
          y_reshaped,
          output_values)) {
    AT_DISPATCH_INDEX_TYPES(
        x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] {
          auto nnz = output_values.size(0);
          auto B = y_reshaped.size(0);
          auto E = y_reshaped.size(2);
          Tensor t_rows_after_bs = at::empty(
              {nnz},
              at::TensorOptions().dtype(at::kInt).device(
                  at::kCUDA, at::cuda::current_device()));
          Tensor t_cols_after_bs = at::empty(
              {nnz},
              at::TensorOptions().dtype(at::kInt).device(
                  at::kCUDA, at::cuda::current_device()));

          // Binary search
          size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
          auto cur_max_shared_bytes =
              at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
          if (dynamic_smem_size > cur_max_shared_bytes) {
            int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device());
            set_max_dynamic_shared_mem_size_for_opt_search_kernel<index_t>(used_shared_bytes);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
            TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
          }
          dim3 threads_bs = dim3(1024, 1, 1);
          dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
          jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
              index_t>
              <<<blocks_bs,
                 threads_bs,
                 dynamic_smem_size,
                 at::cuda::getCurrentCUDAStream()>>>(
                  x_offsets[0]
                      .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
                  t_rows_after_bs
                      .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
                  t_cols_after_bs
                      .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
                  nnz,
                  B);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          // Gather kernel
          dim3 threads = dim3(16, 16, 1);
          dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1);
          if (blocks.y > 65535) {
            blocks.y = 65535;
          }
          jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
              index_t>
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                  output_values
1387                       .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
1388                   x_values
1389                       .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
1390                   y_reshaped
1391                       .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
1392                   y_reshaped
1393                       .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
1394                   t_rows_after_bs
1395                       .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1396                   t_cols_after_bs
1397                       .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1398                   nnz,
1399                   E,
1400                   [f] __device__(__half x, __half y0, __half) -> __half {
1401                     // NB: added the static_casts here
1402                     return static_cast<__half>(
1403                         f(static_cast<scalar_t>(x), static_cast<scalar_t>(y0))
1404                     );
1405                   });
1406           C10_CUDA_KERNEL_LAUNCH_CHECK();
1407         }); // AT_DISPATCH
1408   } else {
1409     JAGGED_TENSOR_DISPATCH_DIMS();
1410     C10_CUDA_KERNEL_LAUNCH_CHECK();
1411   }
1412 }
1413 
1414 at::Tensor _fbgemm_jagged_to_padded_dense_forward(
1415     const Tensor& values,
1416     TensorList offsets,
1417     c10::IntArrayRef max_lengths,
1418     const double padding_value) {
1419   const size_t num_jagged_dim = offsets.size();
1420   TORCH_CHECK(
1421       max_lengths.size() == num_jagged_dim,
1422       "max_lengths.size(), ",
1423       max_lengths.size(),
1424       " != num_jagged_dim, ",
1425       num_jagged_dim);
1426   at::cuda::OptionalCUDAGuard device_guard;
1427   device_guard.set_index(values.get_device());
1428 
1429   const Tensor values_canonicalized = values.view(
1430       {values.size(0),
1431        std::accumulate(
1432            values.sizes().begin() + 1,
1433            values.sizes().end(),
1434            1,
1435            std::multiplies<size_t>())});
1436   at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
1437   padded_values_shape.insert(
1438       padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
1439 
1440   // Canonicalize padded_values by unsqueezing the last dim if the inner
1441   // dense dimension is 1 and has been folded into the values.
1442   const bool D_folded = values.dim() == 1;
1443   if (!D_folded) {
1444     padded_values_shape.push_back(values.size(-1));
1445   }
1446   Tensor padded_values =
1447       at::empty_symint(padded_values_shape, values.options());
1448   Tensor padded_values_view =
1449       D_folded ? padded_values.unsqueeze(-1) : padded_values;
1450 
1451   AT_DISPATCH_ALL_TYPES_AND2(
1452       at::ScalarType::Half,
1453       at::ScalarType::BFloat16,
1454       values.scalar_type(),
1455       "jagged_to_padded_dense",
1456       [&] {
1457         jagged_dense_elementwise_dense_output_<scalar_t>(
1458             values_canonicalized,
1459             offsets.vec(),
1460             padded_values_view, // dummy, not used in the lambda function
1461             padded_values_view,
1462             [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
1463               return x;
1464             },
1465             static_cast<scalar_t>(padding_value));
1466       });
1467 
1468   return padded_values;
1469 }
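// A minimal call sketch (hypothetical values, kept inside a comment so this
// translation unit is unchanged): with values of shape [7, 4], a single
// offsets tensor {0, 3, 7} (so batch size 2) and max_lengths = {5}, the
// returned tensor has shape [2, 5, 4]; positions beyond each batch's own
// length (rows 3-4 of batch 0, row 4 of batch 1) are filled with
// padding_value.
//
//   auto values  = at::randn({7, 4}, at::device(at::kCUDA));
//   auto offsets = at::tensor({0, 3, 7},
//                             at::dtype(at::kInt).device(at::kCUDA));
//   auto padded  = _fbgemm_jagged_to_padded_dense_forward(
//       values, {offsets}, /*max_lengths=*/{5}, /*padding_value=*/0.0);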
1470 
1471 #define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE)                          \
1472   AT_DISPATCH_CASE(TYPE, [&] {                                       \
1473     jagged_dense_elementwise_jagged_output_opt_<scalar_t>(           \
1474         values,                                                      \
1475         offsets.vec(),                                               \
1476         dense,                                                       \
1477         output,                                                      \
1478         [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \
1479           return y;                                                  \
1480         });                                                          \
1481   })
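// Note: the macro above routes Half inputs through the optimized
// jagged_dense_elementwise_jagged_output_opt_ path, while the dispatch below
// sends the remaining floating, Long and BFloat16 types through the generic
// jagged_dense_elementwise_jagged_output_ path.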
1482 
1483 Tensor _fbgemm_dense_to_jagged_forward_symint(
1484     const Tensor& dense,
1485     TensorList offsets,
1486     std::optional<at::SymInt> total_L) {
1487   // D is the embedding dimension
1488   auto D = dense.size(-1);
1489 
1490   // If total_L is not given then compute it
1491   at::SymInt total_L_computed;
1492   if (total_L.has_value()) {
1493     total_L_computed = total_L.value();
1494   } else {
1495     total_L_computed = (int64_t)offsets.back().max().item<int64_t>();
1496   }
1497   auto values = at::empty_symint({total_L_computed, D}, dense.options());
1498   auto output = at::empty_like(values);
1499 
1500   at::cuda::OptionalCUDAGuard device_guard;
1501   device_guard.set_index(dense.get_device());
1502 
1503   // clang-format off
1504   AT_DISPATCH_SWITCH(
1505       values.scalar_type(),
1506       "dense_to_jagged_gpu_op_forward",
1507       DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half)
1508       // NB: removed this to build
1509       // DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int)
1510       AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
1511           at::ScalarType::Long,
1512           at::ScalarType::BFloat16,
1513           [&] {
1514             jagged_dense_elementwise_jagged_output_<scalar_t>(
1515                 values,
1516                 offsets.vec(),
1517                 dense,
1518                 output,
1519                 [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
1520                   return y;
1521                 }); // device lambda
1522           } // lambda
1523           ) // CASE_FLOATING_TYPES_AND
1524   ); // SWITCH
1525   // clang-format on
1526 
1527 #undef DISPATCH_DENSE_TO_JAGGED_CASE
1528 
1529   return output;
1530 }
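// Continuing the sketch above: _fbgemm_dense_to_jagged_forward_symint(
//     padded, {offsets}, /*total_L=*/7) gathers the padded tensor back into a
// [7, 4] jagged values tensor; when total_L is omitted it is recovered as the
// maximum of the last offsets tensor, i.e. 7 in that example.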
1531 
1532 } // namespace native
1533 } // namespace at
1534