1 #include <cuda_fp16.h>
2 #include <type_traits>
3
4 #include <ATen/ATen.h>
5 #include <ATen/Dispatch.h>
6
7 #include <ATen/cuda/CUDAContext.h>
8 #include <ATen/cuda/detail/KernelUtils.h>
9 #include <ATen/cuda/detail/IndexUtils.cuh>
10 #include <ATen/native/cuda/Loops.cuh>
11 #include <ATen/native/cuda/MemoryAccess.cuh>
12 #include <ATen/native/cuda/PersistentSoftmax.cuh>
13 #include <ATen/native/cuda/block_reduce.cuh>
14
15 #include <c10/cuda/CUDAGuard.h>
16 #include <c10/cuda/CUDAMathCompat.h>
17 #include <c10/cuda/CUDAStream.h>
18
19 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
20 #include <ATen/native/nested/NestedTensorUtils.h>
21
22 #ifndef USE_ROCM
23 #ifndef _WIN32
24 #include <cutlass/gemm/device/default_gemm_configuration.h>
25 #include <cutlass/gemm/device/gemm_grouped.h>
26 #include <cutlass/gemm/kernel/default_gemm_grouped.h>
27 #endif
28 #endif
29
30 #include <ATen/NestedTensorImpl.h>
31
32 #define BLOCK_DIM 256
33 #define GRID_DIM_Y 16
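// Launch geometry shared by the padding/unpadding kernels below: grid.x
// indexes a batch entry, grid.y spans GRID_DIM_Y slices of that entry, each
// block runs BLOCK_DIM threads, and every thread strides through its entry
// with a grain of GRID_DIM_Y * BLOCK_DIM elements.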
34
35 namespace at {
36 namespace native {
37
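// Strips padding from a 4-D padded input of shape
// [batch, input_sizes[1], input_sizes[2], input_sizes[3]] while applying a
// (0, 2, 1, 3) permutation: output element (i2, i1, i3) of batch entry b reads
// input[b][i1][i2][i3], and each batch entry is written contiguously starting
// at offsets[b]. output_sizes holds the per-entry collapsed output extents.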
38 template <typename T>
39 __global__ void remove_padding_transform0213_2(
40 const T* input,
41 T* output,
42 const int* offsets,
43 const int* input_sizes,
44 const int* output_sizes,
45 int output_dim,
46 const int batch_size) {
47 const int batch_id = blockIdx.x;
48 const int grid_id = blockIdx.y;
49 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
50 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
51 const int offset = offsets[batch_id];
52 const int* sizes_i = output_sizes + batch_id * output_dim;
53 const int numel_i = sizes_i[0] * sizes_i[1];
54 int input_offset =
55 batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
56 for (int ii = 0; ii < (numel_i / grainsize); ii++) {
57 const int i = ii * grainsize + tid;
58 const int i2 = i / sizes_i[1];
59 const int i13 = i % sizes_i[1];
60 const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
61 const int i3 = i13 % (sizes_i[1] / input_sizes[1]);
62
63 output[offset + i] = input
64 [input_offset + i1 * input_sizes[2] * input_sizes[3] +
65 i2 * input_sizes[3] + i3];
66 }
67 const int i = (numel_i / grainsize) * grainsize + tid;
68 if (i < numel_i) {
69 const int i2 = i / sizes_i[1];
70 const int i13 = i % sizes_i[1];
71 const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
72 const int i3 = i13 % (sizes_i[1] / input_sizes[1]);
73 output[offset + i] = input
74 [input_offset + i1 * input_sizes[2] * input_sizes[3] +
75 i2 * input_sizes[3] + i3];
76 }
77 }
78
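// remove_padding_2 (output_dim == 2) and remove_padding (output_dim == 3)
// below copy, for every batch entry, the leading
// output_sizes[b * output_dim + d] elements along each dimension of the padded
// input into a packed output that starts at offsets[b]; everything outside
// those extents is dropped.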
79 template <typename T>
80 __global__ void remove_padding_2(
81 const T* input,
82 T* output,
83 const int* offsets,
84 const int* input_sizes,
85 const int* output_sizes,
86 int output_dim,
87 const int batch_size) {
88 const int batch_id = blockIdx.x;
89 const int grid_id = blockIdx.y;
90 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
91 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
92 const int offset = offsets[batch_id];
93 const int* sizes_i = output_sizes + batch_id * output_dim;
94 const int numel_i = sizes_i[0] * sizes_i[1];
95 int input_offset = batch_id * input_sizes[1] * input_sizes[2];
96 for (int ii = 0; ii < (numel_i / grainsize); ii++) {
97 const int i = ii * grainsize + tid;
98 const int i0 = i / sizes_i[1];
99 const int i1 = i % sizes_i[1];
100 const int i0_offset = i0 * input_sizes[2];
101 output[offset + i] = input[input_offset + i0_offset + i1];
102 }
103 const int i = (numel_i / grainsize) * grainsize + tid;
104 if (i < numel_i) {
105 const int i0 = i / sizes_i[1];
106 const int i1 = i % sizes_i[1];
107 const int i0_offset = i0 * input_sizes[2];
108 output[offset + i] = input[input_offset + i0_offset + i1];
109 }
110 }
111
112 template <typename T>
113 __global__ void remove_padding(
114 const T* input,
115 T* output,
116 const int* offsets,
117 const int* input_sizes,
118 const int* output_sizes,
119 int output_dim,
120 const int batch_size) {
121 const int batch_id = blockIdx.x;
122 const int grid_id = blockIdx.y;
123 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
124 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
125 const int offset = offsets[batch_id];
126 const int* sizes_i = output_sizes + batch_id * output_dim;
127 const int numel_i = sizes_i[0] * sizes_i[1] * sizes_i[2];
128 int input_offset =
129 batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
130 for (int ii = 0; ii < (numel_i / grainsize); ii++) {
131 const int i = ii * grainsize + tid;
132 const int i0 = i / (sizes_i[1] * sizes_i[2]);
133 const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
134 const int i2 = i % sizes_i[2];
135 const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
136 const int i1_offset = i1 * input_sizes[3];
137 output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
138 }
139 const int i = (numel_i / grainsize) * grainsize + tid;
140 if (i < numel_i) {
141 const int i0 = i / (sizes_i[1] * sizes_i[2]);
142 const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
143 const int i2 = i % sizes_i[2];
144 const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
145 const int i1_offset = i1 * input_sizes[3];
146 output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
147 }
148 }
149
150 template <typename T>
151 void remove_padding_kernelLauncher(
152 const T* input,
153 T* output,
154 const int* offsets,
155 const int* input_sizes,
156 const int* output_sizes,
157 int output_dim,
158 const int batch_size) {
159 dim3 grid;
160 grid.x = batch_size;
161 grid.y = GRID_DIM_Y;
162 at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
163 if (output_dim == 2) {
164 remove_padding_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
165 input,
166 output,
167 offsets,
168 input_sizes,
169 output_sizes,
170 output_dim,
171 batch_size);
172 } else {
173 remove_padding<T><<<grid, BLOCK_DIM, 0, stream>>>(
174 input,
175 output,
176 offsets,
177 input_sizes,
178 output_sizes,
179 output_dim,
180 batch_size);
181 }
182 }
183
184 template <typename T>
185 void remove_padding_transform0213_kernelLauncher(
186 const T* input,
187 T* output,
188 const int* offsets,
189 const int* input_sizes,
190 const int* output_sizes,
191 int output_dim,
192 const int batch_size) {
193 dim3 grid;
194 grid.x = batch_size;
195 grid.y = GRID_DIM_Y;
196 at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
197 TORCH_CHECK(
198 output_dim == 2,
199 "remove padding transform0213 only support output dim == 2");
200
201 remove_padding_transform0213_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
202 input,
203 output,
204 offsets,
205 input_sizes,
206 output_sizes,
207 output_dim,
208 batch_size);
209 }
210
211 template void remove_padding_kernelLauncher<float>(
212 const float* input,
213 float* output,
214 const int* offsets,
215 const int* input_sizes,
216 const int* output_sizes,
217 int output_dim,
218 const int batch_size);
219
220 template void remove_padding_kernelLauncher<c10::Half>(
221 const c10::Half* input,
222 c10::Half* output,
223 const int* offsets,
224 const int* input_sizes,
225 const int* output_sizes,
226 int output_dim,
227 const int batch_size);
228
229 template void remove_padding_transform0213_kernelLauncher<float>(
230 const float* input,
231 float* output,
232 const int* offsets,
233 const int* input_sizes,
234 const int* output_sizes,
235 int output_dim,
236 const int batch_size);
237
238 template void remove_padding_transform0213_kernelLauncher<c10::Half>(
239 const c10::Half* input,
240 c10::Half* output,
241 const int* offsets,
242 const int* input_sizes,
243 const int* output_sizes,
244 int output_dim,
245 const int batch_size);
246
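// The add_padding_{1,2,3} kernels are the inverse of the remove_padding
// kernels above: each packed batch entry (starting at offsets[batch_id], with
// per-entry extents given by input_sizes) is scattered into a dense padded
// output of shape [output_batch_size, output_sizes[1], ...], and every
// position outside the entry's extent is filled with padding_value.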
247 template <typename T>
248 __global__ void add_padding_1(
249 const T* input,
250 T* output,
251 T padding_value,
252 const int* offsets,
253 const int* input_sizes,
254 int input_dim,
255 int output_sizes_1,
256 const int batch_size) {
257 const int batch_id = blockIdx.x;
258 const int grid_id = blockIdx.y;
259 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
260 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
261 const int* sizes_i = input_sizes + batch_id * input_dim;
262 const int batch_output_offset = batch_id * output_sizes_1;
263 for (int ii = 0; ii < (output_sizes_1 / grainsize); ii++) {
264 const int i = ii * grainsize + tid;
265 const int output_offset = batch_output_offset + i;
266 if (batch_id < batch_size && i < sizes_i[0]) {
267 const int batch_input_offset = offsets[batch_id];
268 output[output_offset] = input[batch_input_offset + i];
269 } else {
270 output[output_offset] = padding_value;
271 }
272 }
273 const int i = (output_sizes_1 / grainsize) * grainsize + tid;
274 if (i < output_sizes_1) {
275 const int output_offset = batch_output_offset + i;
276 if (batch_id < batch_size && (i < sizes_i[0])) {
277 const int batch_input_offset = offsets[batch_id];
278 output[output_offset] = input[batch_input_offset + i];
279 } else {
280 output[output_offset] = padding_value;
281 }
282 }
283 }
284
285 template <typename T>
286 __global__ void add_padding_2(
287 const T* input,
288 T* output,
289 T padding_value,
290 const int* offsets,
291 const int* input_sizes,
292 int input_dim,
293 int output_sizes_1,
294 int output_sizes_2,
295 const int batch_size) {
296 const int batch_id = blockIdx.x;
297 const int grid_id = blockIdx.y;
298 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
299 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
300 const int* sizes_i = input_sizes + batch_id * input_dim;
301 const int output_offset = batch_id * output_sizes_1 * output_sizes_2;
302 const int output_numel = output_sizes_1 * output_sizes_2;
303 for (int ii = 0; ii < (output_numel / grainsize); ii++) {
304 const int i = ii * grainsize + tid;
305 const int i0 = i / (output_sizes_2);
306 const int i1 = i - i0 * output_sizes_2;
307 if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1]) {
308 const int offset = offsets[batch_id];
309 const int input_offset = offset + i0 * sizes_i[1] + i1;
310 output[output_offset + i] = input[input_offset];
311 } else {
312 output[output_offset + i] = padding_value;
313 }
314 }
315 const int i = (output_numel / grainsize) * grainsize + tid;
316 if (i < output_numel) {
317 const int i0 = i / (output_sizes_2);
318 const int i1 = i - i0 * output_sizes_2;
319 if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1]) {
320 const int offset = offsets[batch_id];
321 const int input_offset = offset + i0 * sizes_i[1] + i1;
322 output[output_offset + i] = input[input_offset];
323 } else {
324 output[output_offset + i] = padding_value;
325 }
326 }
327 }
328
329 template <typename T>
330 __global__ void add_padding_3(
331 const T* input,
332 T* output,
333 T padding_value,
334 const int* offsets,
335 const int* input_sizes,
336 int input_dim,
337 int output_sizes_1,
338 int output_sizes_2,
339 int output_sizes_3,
340 const int batch_size) {
341 const int batch_id = blockIdx.x;
342 const int grid_id = blockIdx.y;
343 const int tid = threadIdx.x + grid_id * BLOCK_DIM;
344 const int grainsize = GRID_DIM_Y * BLOCK_DIM;
345 const int* sizes_i = input_sizes + batch_id * input_dim;
346 const int output_offset =
347 batch_id * output_sizes_1 * output_sizes_2 * output_sizes_3;
348 const int output_numel = output_sizes_1 * output_sizes_2 * output_sizes_3;
349 for (int ii = 0; ii < (output_numel / grainsize); ii++) {
350 const int i = ii * grainsize + tid;
351 const int i0 = i / (output_sizes_2 * output_sizes_3);
352 const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
353 const int i2 = i % output_sizes_3;
354 if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
355 i2 < sizes_i[2]) {
356 const int offset = offsets[batch_id];
357 const int input_offset =
358 offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
359 output[output_offset + i] = input[input_offset];
360 } else {
361 output[output_offset + i] = padding_value;
362 }
363 }
364 const int i = (output_numel / grainsize) * grainsize + tid;
365 if (i < output_numel) {
366 const int i0 = i / (output_sizes_2 * output_sizes_3);
367 const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
368 const int i2 = i % output_sizes_3;
369 if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
370 i2 < sizes_i[2]) {
371 const int offset = offsets[batch_id];
372 const int input_offset =
373 offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
374 output[output_offset + i] = input[input_offset];
375 } else {
376 output[output_offset + i] = padding_value;
377 }
378 }
379 }
380
381 template <typename T>
382 void add_padding_kernelLauncher(
383 T* input, // [batch_size x None]
384 T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
385 T padding_value,
386 const int* offsets,
387 const int* input_sizes,
388 int input_dim,
389 const std::vector<int64_t>& output_sizes,
390 const int batch_size,
391 const int output_batch_size) {
392 at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
393 dim3 grid;
394 grid.x = output_batch_size;
395 grid.y = GRID_DIM_Y;
396 if (input_dim == 1) {
397 add_padding_1<T><<<grid, BLOCK_DIM, 0, stream>>>(
398 input,
399 output,
400 padding_value,
401 offsets,
402 input_sizes,
403 input_dim,
404 output_sizes[1],
405 batch_size);
406 }
407 if (input_dim == 2) {
408 add_padding_2<T><<<grid, BLOCK_DIM, 0, stream>>>(
409 input,
410 output,
411 padding_value,
412 offsets,
413 input_sizes,
414 input_dim,
415 output_sizes[1],
416 output_sizes[2],
417 batch_size);
418 }
419 if (input_dim == 3) {
420 add_padding_3<T><<<grid, BLOCK_DIM, 0, stream>>>(
421 input,
422 output,
423 padding_value,
424 offsets,
425 input_sizes,
426 input_dim,
427 output_sizes[1],
428 output_sizes[2],
429 output_sizes[3],
430 batch_size);
431 }
432 }
433
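// A minimal host-side sketch of how this launcher can be driven (the tensor
// names and shapes below are illustrative, not taken from a caller in this
// file): with the nested tensor's packed values, per-entry sizes, and
// element-offset prefix sums already materialized as int32 CUDA tensors,
// padding to a dense [B, max_len, D] buffer is a single call:
//
//   add_padding_kernelLauncher(
//       packed.data_ptr<float>(),   // flattened [sum_i(len_i), D] values
//       padded.data_ptr<float>(),   // [B, max_len, D] output buffer
//       0.f,                        // padding value
//       offsets.data_ptr<int>(),    // per-entry element offsets into packed
//       sizes.data_ptr<int>(),      // B x 2 per-entry sizes (len_i, D)
//       /*input_dim=*/2,
//       /*output_sizes=*/{B, max_len, D},
//       /*batch_size=*/B,
//       /*output_batch_size=*/B);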
434 template void add_padding_kernelLauncher<double>(
435 double* input,
436 double* output,
437 double padding_value,
438 const int* offsets,
439 const int* input_sizes,
440 int input_dim,
441 const std::vector<int64_t>& output_sizes,
442 const int batch_size,
443 const int output_batch_size);
444
445 template void add_padding_kernelLauncher<float>(
446 float* input,
447 float* output,
448 float padding_value,
449 const int* offsets,
450 const int* input_sizes,
451 int input_dim,
452 const std::vector<int64_t>& output_sizes,
453 const int batch_size,
454 const int output_batch_size);
455
456 template void add_padding_kernelLauncher<c10::Half>(
457 c10::Half* input,
458 c10::Half* output,
459 c10::Half padding_value,
460 const int* offsets,
461 const int* input_sizes,
462 int input_dim,
463 const std::vector<int64_t>& output_sizes,
464 const int batch_size,
465 const int output_batch_size);
466
467 // NB: The following code covers jagged <-> padded dense conversions and was lifted
468 // from fbgemm_gpu. For more details, see
469 // https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/src/jagged_tensor_ops
470
471 // Pass the lambda expression argument by value instead of by reference to
472 // avoid an "internal compiler error: in maybe_undo_parenthesized_ref" error
473 // with specific compiler versions.
474 #define JAGGED_TENSOR_DISPATCH_DIMS() \
475 AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \
476 switch (num_jagged_dim) { \
477 case 1: \
478 INVOKE_KERNEL_WITH_DIM(1); \
479 break; \
480 case 2: \
481 INVOKE_KERNEL_WITH_DIM(2); \
482 break; \
483 case 3: \
484 INVOKE_KERNEL_WITH_DIM(3); \
485 break; \
486 case 4: \
487 INVOKE_KERNEL_WITH_DIM(4); \
488 break; \
489 case 5: \
490 INVOKE_KERNEL_WITH_DIM(5); \
491 break; \
492 default: \
493 TORCH_CHECK( \
494 false, "unsupported number of jagged dim ", num_jagged_dim); \
495 } \
496 });
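// JAGGED_TENSOR_DISPATCH_DIMS() expects the caller to have #define'd
// INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) first: it dispatches on the offsets'
// index type and on num_jagged_dim (1..5), expanding the macro with a
// compile-time dimension so the kernels can unroll their per-dimension loops.
// jagged_dense_elementwise_dense_output_ below shows the pattern: define the
// macro, call JAGGED_TENSOR_DISPATCH_DIMS(), then #undef the macro.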
497
498 inline std::string torch_tensor_device_name(const at::Tensor& ten) {
499 return c10::DeviceTypeName(ten.device().type());
500 }
501
502 inline std::string torch_tensor_device_name(
503 const std::optional<at::Tensor>& ten) {
504 if (ten.has_value()) {
505 return torch_tensor_device_name(ten.value());
506 } else {
507 return "N/A";
508 }
509 }
510
511 inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) {
512 return ten.is_cuda();
513 }
514
515 inline bool torch_tensor_on_cuda_gpu_check(
516 const std::optional<at::Tensor>& ten) {
517 return !ten.has_value() || torch_tensor_on_cuda_gpu_check(ten.value());
518 }
519
520 #define TENSOR_ON_CUDA_GPU(x) \
521 TORCH_CHECK( \
522 torch_tensor_on_cuda_gpu_check(x), \
523 #x " must be a CUDA tensor; it is currently on device ", \
524 torch_tensor_device_name(x))
525
526 // A wrapper class for passing dynamically sized dimension information (e.g.
527 // tensor.dims()) from the host to device.
528 constexpr size_t kStackArrayMaxDims = 5;
529
530 template <typename T>
531 struct StackArray {
532 T vals[kStackArrayMaxDims];
533 size_t ndim;
534 };
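// StackArray is passed to kernels by value, so its (at most kStackArrayMaxDims)
// entries travel in kernel parameter space instead of behind a device pointer.
// check_shape_and_partition_ below fills one with the folded jagged dims, and
// the INVOKE_KERNEL_WITH_DIM macros fill another with one raw offsets pointer
// per jagged dimension.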
535
536 // Warp size
537 #ifdef USE_ROCM
538 static constexpr int32_t kWarpSize = 64;
539 #else
540 static constexpr int32_t kWarpSize = 32;
541 #endif
542 // Max thread num in one thread block
543 static constexpr int32_t kMaxThreads = 1024;
544
545 #define DEVICE_INLINE __device__ C10_ALWAYS_INLINE
546
547 __host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) {
548 return (a + b - 1) / b;
549 }
550
551 __host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) {
552 return a / b * b;
553 }
554
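// For example, div_round_up(10, 3) == 4 (ceiling division) and
// round_down(10, 3) == 9 (largest multiple of b not exceeding a); both assume
// b > 0.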
555 inline std::tuple<dim3, dim3, StackArray<int64_t>> check_shape_and_partition_(
556 const Tensor& values,
557 const std::vector<Tensor>& offsets,
558 const Tensor& dense_tensor) {
559 const int outer_dense_size = dense_tensor.size(0);
560 TORCH_CHECK(
561 outer_dense_size == offsets[0].numel() - 1,
562 "outer_dense_size, ",
563 outer_dense_size,
564 " != offsets[0].numel() - 1, ",
565 offsets[0].numel() - 1);
566 const int inner_dense_size = dense_tensor.size(-1);
567 TORCH_CHECK(
568 inner_dense_size == values.size(-1),
569 "inner_dense_size, ",
570 inner_dense_size,
571 " != values.size(-1), ",
572 values.size(-1));
573 const int jagged_folded_size =
574 dense_tensor.numel() / (outer_dense_size * inner_dense_size);
575
576 const int threads_x =
577 inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size;
578 const int threads_y = kMaxThreads / kWarpSize;
579 const dim3 blocks(
580 div_round_up(outer_dense_size * jagged_folded_size, threads_y));
581
582 StackArray<int64_t> jagged_dims_tensor;
583 const int num_jagged_dim = dense_tensor.dim() - 2;
584 TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims);
585 jagged_dims_tensor.ndim = num_jagged_dim;
586 std::memcpy(
587 &(jagged_dims_tensor.vals[0]),
588 dense_tensor.sizes().data() + 1,
589 num_jagged_dim * sizeof(int64_t));
590 return {dim3(threads_x, threads_y), blocks, jagged_dims_tensor};
591 }
592
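// The partition above maps the folded dense tensor onto 2-D blocks: threads.x
// (at most one warp) strides over the inner dense dimension, each of the
// kMaxThreads / kWarpSize rows in threads.y handles one (outer, folded-jagged)
// index, and blocks.x covers all outer_dense_size * jagged_folded_size such
// rows.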
593 template <int NUM_JAGGED_DIM, typename index_t>
594 DEVICE_INLINE bool walk_down_tensor_storage_tree_(
595 int& offset,
596 const int flattened_jagged_idx,
597 const StackArray<int64_t>& jagged_dims,
598 const StackArray<index_t*>& x_offsets) {
599 // compute coordinates
600 int jagged_coords[NUM_JAGGED_DIM];
601 int j_temp = flattened_jagged_idx;
602 #pragma unroll
603 for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
604 const int jagged_size = jagged_dims.vals[d];
605 jagged_coords[d] = j_temp % jagged_size;
606 j_temp /= jagged_size;
607 }
608
609 // walk down the tree
610 bool is_zero = false;
611 #pragma unroll
612 for (int d = 0; d < NUM_JAGGED_DIM; ++d) {
613 const int begin = x_offsets.vals[d][offset];
614 const int end = x_offsets.vals[d][offset + 1];
615 if (jagged_coords[d] >= end - begin) {
616 is_zero = true;
617 break;
618 }
619 offset = begin + jagged_coords[d];
620 }
621 return is_zero;
622 }
623
624 // output = f(x, y) where x is jagged, y is dense, and output is dense.
625 // A generic elementwise operation between a jagged tensor and a dense tensor
626 // This kernel assumes jagged dims are clustered together, preceded by outer
627 // dense dimensions and followed by inner dense dimensions.
628 // The outer/inner dense dimensions, and jagged dimensions in between are
629 // assumed to be folded so physically the dense tensor is 3D and the value of
630 // jagged tensor is 2D.
631 // To support an arbitrary number of jagged dimensions, we pass a vector of
632 // pointers to offset tensors (this is ugly; we could probably use a nested
633 // tensor here).
634 // This kernel parallelizes the (folded) inner dense dimension across
635 // blockDim.x so the inner dense dimension should be similar to or bigger than
636 // warp size.
637 // We rely on the compiler to unroll loops over the compile-time constant NUM_JAGGED_DIM.
638 template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
639 __global__
640 __launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_(
641 const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
642 x_values,
643 StackArray<index_t*> x_offsets,
644 const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y,
645 at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> output,
646 StackArray<int64_t> jagged_dims,
647 F f,
648 const scalar_t padding_value) {
649 const int outer_dense_size = y.size(0);
650 const int jagged_folded_size = y.size(1);
651 const int inner_dense_size = y.size(2);
652
653 const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y;
654 const int outer_stride = gridDim.x * blockDim.y;
655 for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
656 outer += outer_stride) {
657 const int oidx = outer / jagged_folded_size;
658 const int jidx = outer % jagged_folded_size;
659
660 int offset = oidx;
661 const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
662 offset, jidx, jagged_dims, x_offsets);
663
664 if (is_zero) {
665 int iidx;
666 for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
667 iidx += blockDim.x) {
668 output[oidx][jidx][2 * iidx] =
669 f(padding_value, y[oidx][jidx][2 * iidx]);
670 output[oidx][jidx][2 * iidx + 1] =
671 f(padding_value, y[oidx][jidx][2 * iidx + 1]);
672 }
673 if (iidx * 2 + 1 == inner_dense_size) {
674 output[oidx][jidx][2 * iidx] =
675 f(padding_value, y[oidx][jidx][2 * iidx]);
676 }
677 } else {
678 int iidx;
679 for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
680 iidx += blockDim.x) {
681 output[oidx][jidx][2 * iidx] =
682 f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
683 output[oidx][jidx][2 * iidx + 1] =
684 f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]);
685 }
686 if (iidx * 2 + 1 == inner_dense_size) {
687 output[oidx][jidx][2 * iidx] =
688 f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
689 }
690 }
691 }
692 }
693
694 template <typename scalar_t, typename F>
695 void jagged_dense_elementwise_dense_output_(
696 const Tensor& x_values,
697 const std::vector<Tensor>& x_offsets,
698 const Tensor& y,
699 const Tensor& output,
700 F f,
701 const scalar_t padding_value = static_cast<scalar_t>(0)) {
702 TENSOR_ON_CUDA_GPU(x_values);
703 for (auto& x_offset : x_offsets) {
704 TENSOR_ON_CUDA_GPU(x_offset);
705 }
706
707 const int num_jagged_dim = y.dim() - 2;
708 TORCH_CHECK(
709 x_offsets.size() == static_cast<size_t>(num_jagged_dim),
710 "x_offsets.size(), ",
711 x_offsets.size(),
712 " != num_jagged_dim ",
713 num_jagged_dim);
714
715 if (y.numel() == 0) {
716 return;
717 }
718
719 dim3 threads, blocks;
720 StackArray<int64_t> jagged_dims_tensor;
721 std::tie(threads, blocks, jagged_dims_tensor) =
722 check_shape_and_partition_(x_values, x_offsets, y);
723
724 // Canonicalize y and output to 3D, collapsing jagged dimensions.
725 const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
726 Tensor output_reshaped = output.view(y_reshaped.sizes());
727
728 #define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
729 { \
730 std::vector<Tensor> x_offsets_contig; \
731 x_offsets_contig.resize(num_jagged_dim); \
732 StackArray<index_t*> x_offset_ptrs; \
733 x_offset_ptrs.ndim = num_jagged_dim; \
734 for (int d = 0; d < num_jagged_dim; ++d) { \
735 x_offsets_contig[d] = x_offsets[d].contiguous(); \
736 x_offset_ptrs.vals[d] = \
737 x_offsets_contig[d].template data_ptr<index_t>(); \
738 } \
739 jagged_dense_elementwise_dense_output_kernel_<NUM_JAGGED_DIM, index_t> \
740 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
741 x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
742 x_offset_ptrs, \
743 y_reshaped \
744 .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
745 output_reshaped \
746 .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
747 jagged_dims_tensor, \
748 f, \
749 padding_value); \
750 }
751
752 JAGGED_TENSOR_DISPATCH_DIMS();
753 C10_CUDA_KERNEL_LAUNCH_CHECK();
754
755 #undef INVOKE_KERNEL_WITH_DIM
756 }
757
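// A concrete use of jagged_dense_elementwise_dense_output_ appears in
// _fbgemm_jagged_to_padded_dense_forward near the end of this file: it passes
// the padded output as both y and output together with an identity lambda,
// which turns the generic elementwise op into a plain jagged-to-padded-dense
// copy with padding_value fills.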
758 #define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
759 { \
760 auto [threads, blocks, jagged_dims_tensor] = \
761 check_shape_and_partition_(x_values, x_offsets, y); \
762 blocks.x = div_round_up(x_values.size(0), threads.y); \
763 std::vector<Tensor> x_offsets_contig; \
764 x_offsets_contig.resize(num_jagged_dim); \
765 StackArray<index_t*> x_offset_ptrs; \
766 x_offset_ptrs.ndim = num_jagged_dim; \
767 StackArray<int64_t> x_offset_sizes; \
768 x_offset_sizes.ndim = num_jagged_dim; \
769 for (int d = 0; d < num_jagged_dim; ++d) { \
770 x_offsets_contig[d] = x_offsets[d].contiguous(); \
771 x_offset_ptrs.vals[d] = \
772 x_offsets_contig[d].template data_ptr<index_t>(); \
773 x_offset_sizes.vals[d] = x_offsets[d].numel(); \
774 } \
775 jagged_dense_dense_elementwise_jagged_output_kernel_< \
776 NUM_JAGGED_DIM, \
777 index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
778 x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
779 x_offset_ptrs, \
780 x_offset_sizes, \
781 y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
782 y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
783 output_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
784 jagged_dims_tensor, \
785 [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \
786 -> scalar_t { return f(x, y); }); \
787 }
788
789 template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
790 __global__
791 __launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_(
792 const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
793 x_values,
794 StackArray<index_t*> x_offsets,
795 StackArray<int64_t> x_offsets_sizes,
796 const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_0,
797 const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_1,
798 at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
799 output_values,
800 StackArray<int64_t> jagged_dims,
801 F f) {
802 const int outer_dense_size = y_0.size(0);
803 const int inner_dense_size = y_0.size(2);
804 const int nnz = x_values.size(0);
805
806 const int offset_begin = blockIdx.x * blockDim.y + threadIdx.y;
807 const int offset_stride = gridDim.x * blockDim.y;
808 for (int offset = offset_begin; offset < nnz; offset += offset_stride) {
809 int offset_temp = offset;
810 int jidx = 0;
811 bool truncated = false;
812 int dim_prod = 1;
813 #pragma unroll
814 for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
815 // Binary search for the first offset entry that is greater than offset_temp
816 int count = x_offsets_sizes.vals[d] - 1;
817 int first = 1;
818 while (count > 0) {
819 int idx = first;
820 int step = count / 2;
821 idx += step;
822 if (x_offsets.vals[d][idx] <= offset_temp) {
823 first = ++idx;
824 count -= step + 1;
825 } else {
826 count = step;
827 }
828 }
829
830 --first;
831 int coord = offset_temp - x_offsets.vals[d][first];
832 if (coord >= jagged_dims.vals[d]) {
833 truncated = true;
834 break;
835 }
836 jidx += coord * dim_prod;
837 dim_prod *= jagged_dims.vals[d];
838 offset_temp = first;
839 }
840
841 if (offset_temp >= outer_dense_size) {
842 // This can happen when the values tensor has more elements than the last
843 // element of offsets.
844 truncated = true;
845 }
846 if (!truncated) {
847 const int oidx = offset_temp;
848 int iidx;
849 for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
850 iidx += blockDim.x) {
851 output_values[offset][2 * iidx] =
852 f(x_values[offset][2 * iidx],
853 y_0[oidx][jidx][2 * iidx],
854 y_1[oidx][jidx][2 * iidx]);
855 output_values[offset][2 * iidx + 1] =
856 f(x_values[offset][2 * iidx + 1],
857 y_0[oidx][jidx][2 * iidx + 1],
858 y_1[oidx][jidx][2 * iidx + 1]);
859 }
860 if (iidx * 2 + 1 == inner_dense_size) {
861 output_values[offset][2 * iidx] =
862 f(x_values[offset][2 * iidx],
863 y_0[oidx][jidx][2 * iidx],
864 y_1[oidx][jidx][2 * iidx]);
865 }
866 } else {
867 int iidx;
868 for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
869 iidx += blockDim.x) {
870 output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0);
871 output_values[offset][2 * iidx + 1] =
872 f(x_values[offset][2 * iidx + 1], 0, 0);
873 }
874 if (iidx * 2 + 1 == inner_dense_size) {
875 output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0);
876 }
877 }
878 }
879 }
880
881 ///@addtogroup jagged-tensor-ops-cuda
882 template <typename scalar_t, typename F>
883 void jagged_dense_elementwise_jagged_output_(
884 const Tensor& x_values,
885 const std::vector<Tensor>& x_offsets,
886 const Tensor& y,
887 const Tensor& output_values,
888 F f) {
889 TENSOR_ON_CUDA_GPU(x_values);
890 for (auto& x_offset : x_offsets) {
891 TENSOR_ON_CUDA_GPU(x_offset);
892 }
893
894 const int num_jagged_dim = y.dim() - 2;
895 TORCH_CHECK(
896 x_offsets.size() == static_cast<size_t>(num_jagged_dim),
897 "x_offsets.size(), ",
898 x_offsets.size(),
899 " != num_jagged_dim, ",
900 num_jagged_dim);
901
902 if (y.numel() == 0 || x_values.numel() == 0) {
903 return;
904 }
905
906 // Canonicalize y to 3D, collapsing jagged dimensions.
907 const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
908
909 JAGGED_TENSOR_DISPATCH_DIMS();
910 C10_CUDA_KERNEL_LAUNCH_CHECK();
911 }
912
913 #undef INVOKE_KERNEL_WITH_DIM
914
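// SharedMemory is the usual CUDA workaround for dynamic shared memory in
// templated code: an `extern __shared__` array cannot have a
// template-dependent element type (every instantiation would redeclare the
// same symbol with a different type), so the struct is specialized per index
// type and hands back a correctly typed pointer to the block's dynamic shared
// memory.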
915 template <typename T>
916 struct SharedMemory;
917
918 template <>
919 struct SharedMemory<int64_t> {
920 __device__ int64_t* getPointer() {
921 extern __shared__ int64_t s_int64_t[];
922 return s_int64_t;
923 }
924 };
925
926 template <>
927 struct SharedMemory<int32_t> {
928 __device__ int32_t* getPointer() {
929 extern __shared__ int32_t s_int32_t[];
930 return s_int32_t;
931 }
932 };
933
934 template <typename index_t>
935 __global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_(
936 const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
937 at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
938 at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
939 int nnz,
940 int B) {
941 struct SharedMemory<index_t> smem;
942 index_t* offsets_sh = smem.getPointer();
943
944 for (int i = threadIdx.x; i < B + 1; i += blockDim.x) {
945 offsets_sh[i] = offsets[i];
946 }
947 __syncthreads();
948 int row = threadIdx.x + blockIdx.x * blockDim.x;
949 if (row >= nnz)
950 return;
951 int first = -1;
952 int count = B - 1;
953 first = 1;
954 while (count > 0) {
955 int idx = first;
956 int step = count / 2;
957 idx += step;
958 if (offsets_sh[idx] <= row) {
959 first = ++idx;
960 count -= step + 1;
961 } else {
962 count = step;
963 }
964 }
965 --first;
966
967 int dense_row = first;
968 int offset = offsets_sh[dense_row];
969 int dense_col = row - offset;
970 rows[row] = dense_row;
971 cols[row] = dense_col;
972 }
973
974 struct VecType128 {
975 typedef float4 TType; // Transaction Type
976 typedef struct __align__(16) {
977 __half a, b, c, d, w, x, y, z;
978 }
979 half8;
980
981 union Data {
982 half8 val;
983 TType mask;
984 } data;
985
986 __device__ VecType128() {
987 data.mask = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
988 }
989 };
990
991 struct VecType64 {
992 typedef float2 TType; // Transaction Type
993 typedef struct __align__(8) {
994 __half a, b, c, d;
995 }
996 half4;
997
998 union Data {
999 half4 val;
1000 TType mask;
1001 } data;
1002
1003 __device__ VecType64() {
1004 data.mask = make_float2(0.0f, 0.0f);
1005 }
1006 };
1007
1008 struct VecType32 {
1009 typedef float TType; // Transaction Type
1010
1011 union Data {
1012 __half2 val;
1013 TType mask;
1014 } data;
1015
1016 __device__ VecType32() {
1017 data.mask = 0.0f;
1018 }
1019 };
1020
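// VecType128/VecType64/VecType32 pack 8, 4, or 2 __half values behind a
// float4 / float2 / float "transaction type" so the gather kernel below can
// move them with single 128/64/32-bit loads and stores; f128/f64/f32 then
// apply the scalar functor lane by lane, with a scalar __half fallback for any
// leftover tail of the embedding dimension.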
1021 template <typename F>
1022 __device__ void f128(
1023 VecType128& v_out,
1024 const VecType128& x,
1025 const VecType128& y0,
1026 const VecType128& y1,
1027 F f) {
1028 v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a);
1029 v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b);
1030 v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c);
1031 v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d);
1032 v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w);
1033 v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x);
1034 v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y);
1035 v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z);
1036 }
1037
1038 template <typename F>
1039 __device__ void f64(
1040 VecType64& v_out,
1041 const VecType64& x,
1042 const VecType64& y0,
1043 const VecType64& y1,
1044 F f) {
1045 v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a);
1046 v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b);
1047 v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c);
1048 v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d);
1049 }
1050
1051 template <typename F>
1052 __device__ void f32(
1053 VecType32& v_out,
1054 const VecType32& x,
1055 const VecType32& y0,
1056 const VecType32& y1,
1057 F f) {
1058 v_out.data.val = __halves2half2(
1059 f(__low2half(x.data.val),
1060 __low2half(y0.data.val),
1061 __low2half(y1.data.val)),
1062 f(__high2half(x.data.val),
1063 __high2half(y0.data.val),
1064 __high2half(y1.data.val)));
1065 }
1066
1067 template <typename index_t, typename F>
1068 __global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_(
1069 at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits> values,
1070 const at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits>
1071 x_values,
1072 const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y0,
1073 const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y1,
1074 const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
1075 const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
1076 const int nnz,
1077 const int E,
1078 F f) {
1079 int values_row = threadIdx.y + blockIdx.y * blockDim.y;
1080 if (values_row >= nnz)
1081 return;
1082 for (int real_row = values_row; real_row < nnz;
1083 real_row += blockDim.y * gridDim.y) {
1084 int dense_row = rows[real_row];
1085 int dense_col = cols[real_row];
1086 __half* values_ptr = reinterpret_cast<__half*>(&values[real_row][0]);
1087 const __half* x_ptr =
1088 reinterpret_cast<const __half*>(&x_values[real_row][0]);
1089 const __half* y0_ptr =
1090 reinterpret_cast<const __half*>(&y0[dense_row][dense_col][0]);
1091 const __half* y1_ptr =
1092 reinterpret_cast<const __half*>(&y1[dense_row][dense_col][0]);
1093 if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) &&
1094 (dense_col < y1.size(1)) && (dense_row < y1.size(0)) &&
1095 (dense_col >= 0) && (dense_row >= 0)) {
1096 for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) {
1097 VecType128 v_x, v_out, v_y0, v_y1;
1098 v_x.data.mask =
1099 (reinterpret_cast<const VecType128::TType*>(x_ptr))[tid];
1100 v_y0.data.mask =
1101 (reinterpret_cast<const VecType128::TType*>(y0_ptr))[tid];
1102 v_y1.data.mask =
1103 (reinterpret_cast<const VecType128::TType*>(y1_ptr))[tid];
1104 f128(v_out, v_x, v_y0, v_y1, f);
1105 (reinterpret_cast<VecType128::TType*>(values_ptr))[tid] =
1106 v_out.data.mask;
1107 }
1108 for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4;
1109 tid += blockDim.x) {
1110 VecType64 v_x, v_out, v_y0, v_y1;
1111 v_x.data.mask = (reinterpret_cast<const VecType64::TType*>(x_ptr))[tid];
1112 v_y0.data.mask =
1113 (reinterpret_cast<const VecType64::TType*>(y0_ptr))[tid];
1114 v_y1.data.mask =
1115 (reinterpret_cast<const VecType64::TType*>(y1_ptr))[tid];
1116 f64(v_out, v_x, v_y0, v_y1, f);
1117 (reinterpret_cast<VecType64::TType*>(values_ptr))[tid] =
1118 v_out.data.mask;
1119 }
1120 for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2;
1121 tid += blockDim.x) {
1122 VecType32 v_x, v_out, v_y0, v_y1;
1123 v_x.data.mask = (reinterpret_cast<const VecType32::TType*>(x_ptr))[tid];
1124 v_y0.data.mask =
1125 (reinterpret_cast<const VecType32::TType*>(y0_ptr))[tid];
1126 v_y1.data.mask =
1127 (reinterpret_cast<const VecType32::TType*>(y1_ptr))[tid];
1128 f32(v_out, v_x, v_y0, v_y1, f);
1129 (reinterpret_cast<VecType32::TType*>(values_ptr))[tid] =
1130 v_out.data.mask;
1131 }
1132 for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) {
1133 auto v_x = static_cast<__half>(x_ptr[tid]);
1134 auto v_y0 = static_cast<__half>(y0_ptr[tid]);
1135 auto v_y1 = static_cast<__half>(y1_ptr[tid]);
1136 values_ptr[tid] = f(v_x, v_y0, v_y1);
1137 }
1138 } else {
1139 for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) {
1140 VecType128 v_x, v_out, v_y0, v_y1;
1141 v_x.data.mask =
1142 (reinterpret_cast<const VecType128::TType*>(x_ptr))[tid];
1143 f128(v_out, v_x, v_y0, v_y1, f);
1144 (reinterpret_cast<VecType128::TType*>(values_ptr))[tid] =
1145 v_out.data.mask;
1146 }
1147 for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4;
1148 tid += blockDim.x) {
1149 VecType64 v_x, v_out, v_y0, v_y1;
1150 v_x.data.mask = (reinterpret_cast<const VecType64::TType*>(x_ptr))[tid];
1151 f64(v_out, v_x, v_y0, v_y1, f);
1152 (reinterpret_cast<VecType64::TType*>(values_ptr))[tid] =
1153 v_out.data.mask;
1154 }
1155 for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2;
1156 tid += blockDim.x) {
1157 VecType32 v_x, v_out, v_y0, v_y1;
1158 v_x.data.mask = (reinterpret_cast<const VecType32::TType*>(x_ptr))[tid];
1159 f32(v_out, v_x, v_y0, v_y1, f);
1160 (reinterpret_cast<VecType32::TType*>(values_ptr))[tid] =
1161 v_out.data.mask;
1162 }
1163 for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) {
1164 auto v_x = static_cast<__half>(x_ptr[tid]);
1165 values_ptr[tid] = f(v_x, __half{}, __half{});
1166 }
1167 }
1168 }
1169 }
1170
1171 // Check to see if the inputs to the op are amenable to the fast path
1172 inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
1173 const int& num_jagged_dim,
1174 const Tensor& x_values,
1175 const std::vector<Tensor>& x_offsets,
1176 const Tensor& y_0_reshaped,
1177 const Tensor& y_1_reshaped,
1178 const Tensor& output_values) {
1179 bool matches = true;
1180 matches &= (num_jagged_dim == 1);
1181
1182 // Unit stride embedding dim
1183 matches &= (x_values.stride(-1) == 1);
1184 matches &= (output_values.stride(-1) == 1);
1185 matches &= (y_0_reshaped.stride(-1) == 1);
1186 matches &= (y_1_reshaped.stride(-1) == 1);
1187
1188 // Each row is aligned to 128-bit
1189 matches &= (x_values.stride(-2) % 8 == 0);
1190 matches &= (output_values.stride(-2) % 8 == 0);
1191 matches &= (y_0_reshaped.stride(-2) % 8 == 0);
1192 matches &= (y_1_reshaped.stride(-2) % 8 == 0);
1193
1194 // Base addresses aligned to 128-bit
1195 matches &= (reinterpret_cast<uint64_t>(x_values.data_ptr()) % 16 == 0);
1196 matches &= (reinterpret_cast<uint64_t>(output_values.data_ptr()) % 16 == 0);
1197 matches &= (reinterpret_cast<uint64_t>(y_0_reshaped.data_ptr()) % 16 == 0);
1198 matches &= (reinterpret_cast<uint64_t>(y_1_reshaped.data_ptr()) % 16 == 0);
1199
1200 // Rows and cols fit into int32_t
1201 matches &= (y_0_reshaped.size(0) < INT_MAX);
1202 matches &= (y_0_reshaped.size(1) < INT_MAX);
1203
1204 int max_shared_bytes;
1205 #ifndef USE_ROCM
1206 C10_CUDA_CHECK(cudaDeviceGetAttribute(
1207 &max_shared_bytes,
1208 cudaDevAttrMaxSharedMemoryPerBlockOptin,
1209 y_0_reshaped.get_device()));
1210 #else
1211 // MI100 has 64 KB local memory (shared memory) per workgroup
1212 max_shared_bytes = 64 << 10;
1213 #endif
1214 int shared_kb = max_shared_bytes >> 10;
1215 #ifndef USE_ROCM
1216 // Use 2/3 of the available GPU shared memory; leave room for the L1 cache.
1217 int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
1218 TORCH_CHECK(used_shared_kb > 0);
1219 #else
1220 // MI100 has independent shared mem and L1
1221 int used_shared_kb = shared_kb;
1222 #endif
1223 int used_shared_bytes = used_shared_kb << 10;
1224 AT_DISPATCH_INDEX_TYPES(
1225 x_offsets[0].scalar_type(), "check_shared_memory", [&] {
1226 auto B = y_0_reshaped.size(0);
1227 // The default per-block shared memory on V100/A100/H100 is 48 KB; see
1228 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
1229 if ((B + 1) * sizeof(index_t) >= used_shared_bytes) {
1230 matches = false;
1231 }
1232 });
1233 return matches;
1234 }
1235
1236 #define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
1237 { \
1238 auto [threads, blocks, jagged_dims_tensor] = \
1239 check_shape_and_partition_(x_values, x_offsets, y); \
1240 blocks.x = div_round_up(x_values.size(0), threads.y); \
1241 std::vector<Tensor> x_offsets_contig; \
1242 x_offsets_contig.resize(num_jagged_dim); \
1243 StackArray<index_t*> x_offset_ptrs; \
1244 x_offset_ptrs.ndim = num_jagged_dim; \
1245 StackArray<int64_t> x_offset_sizes; \
1246 x_offset_sizes.ndim = num_jagged_dim; \
1247 for (int d = 0; d < num_jagged_dim; ++d) { \
1248 x_offsets_contig[d] = x_offsets[d].contiguous(); \
1249 x_offset_ptrs.vals[d] = \
1250 x_offsets_contig[d].template data_ptr<index_t>(); \
1251 x_offset_sizes.vals[d] = x_offsets[d].numel(); \
1252 } \
1253 jagged_dense_dense_elementwise_jagged_output_kernel_< \
1254 NUM_JAGGED_DIM, \
1255 index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
1256 x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
1257 x_offset_ptrs, \
1258 x_offset_sizes, \
1259 y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
1260 y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
1261 output_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
1262 jagged_dims_tensor, \
1263 [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \
1264 -> scalar_t { return f(x, y); }); \
1265 }
1266
1267 inline int calc_used_shared_bytes(const int device) {
1268 int max_shared_bytes;
1269 #ifndef USE_ROCM
1270 C10_CUDA_CHECK(cudaDeviceGetAttribute(
1271 &max_shared_bytes,
1272 cudaDevAttrMaxSharedMemoryPerBlockOptin,
1273 device));
1274 #else
1275 // MI100 has 64 KB local memory (shared memory) per workgroup
1276 max_shared_bytes = 64 << 10;
1277 #endif
1278 int shared_kb = max_shared_bytes >> 10;
1279 #ifndef USE_ROCM
1280 // Use 2/3 of the available GPU shared memory; leave room for the L1 cache.
1281 int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
1282 TORCH_CHECK(used_shared_kb > 0);
1283 #else
1284 // MI100 has independent shared mem and L1
1285 int used_shared_kb = shared_kb;
1286 #endif
1287 int used_shared_bytes = used_shared_kb << 10;
1288 return used_shared_bytes;
1289 }
1290
1291 template <typename index_t>
1292 inline void set_max_dynamic_shared_mem_size_for_opt_search_kernel(const int used_shared_bytes) {
1293 #ifndef USE_ROCM
1294 C10_CUDA_CHECK(cudaFuncSetAttribute(
1295 jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
1296 index_t>,
1297 cudaFuncAttributeMaxDynamicSharedMemorySize,
1298 used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB
1299 #endif
1300 }
1301
1302 ///@addtogroup jagged-tensor-ops-cuda
1303 template <typename scalar_t, typename F>
1304 void jagged_dense_elementwise_jagged_output_opt_(
1305 const Tensor& x_values,
1306 const std::vector<Tensor>& x_offsets,
1307 const Tensor& y,
1308 const Tensor& output_values,
1309 F f) {
1310 TENSOR_ON_CUDA_GPU(x_values);
1311 for (auto& x_offset : x_offsets) {
1312 TENSOR_ON_CUDA_GPU(x_offset);
1313 }
1314
1315 const int num_jagged_dim = y.dim() - 2;
1316 TORCH_CHECK(
1317 x_offsets.size() == static_cast<size_t>(num_jagged_dim),
1318 "x_offsets.size(), ",
1319 x_offsets.size(),
1320 " != num_jagged_dim, ",
1321 num_jagged_dim);
1322
1323 if (y.numel() == 0 || x_values.numel() == 0) {
1324 return;
1325 }
1326
1327 // Canonicalize y to 3D, collapsing jagged dimensions.
1328 const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
1329 if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
1330 num_jagged_dim,
1331 x_values,
1332 x_offsets,
1333 y_reshaped,
1334 y_reshaped,
1335 output_values)) {
1336 AT_DISPATCH_INDEX_TYPES(
1337 x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] {
1338 auto nnz = output_values.size(0);
1339 auto B = y_reshaped.size(0);
1340 auto E = y_reshaped.size(2);
1341 Tensor t_rows_after_bs = at::empty(
1342 {nnz},
1343 at::TensorOptions().dtype(at::kInt).device(
1344 at::kCUDA, at::cuda::current_device()));
1345 Tensor t_cols_after_bs = at::empty(
1346 {nnz},
1347 at::TensorOptions().dtype(at::kInt).device(
1348 at::kCUDA, at::cuda::current_device()));
1349
1350 // Binary search
1351 size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
1352 auto cur_max_shared_bytes =
1353 at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
1354 if (dynamic_smem_size > cur_max_shared_bytes) {
1355 int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device());
1356 set_max_dynamic_shared_mem_size_for_opt_search_kernel<index_t>(used_shared_bytes);
1357 C10_CUDA_KERNEL_LAUNCH_CHECK();
1358 TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
1359 }
1360 dim3 threads_bs = dim3(1024, 1, 1);
1361 dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
1362 jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
1363 index_t>
1364 <<<blocks_bs,
1365 threads_bs,
1366 dynamic_smem_size,
1367 at::cuda::getCurrentCUDAStream()>>>(
1368 x_offsets[0]
1369 .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
1370 t_rows_after_bs
1371 .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1372 t_cols_after_bs
1373 .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1374 nnz,
1375 B);
1376 C10_CUDA_KERNEL_LAUNCH_CHECK();
1377 // Gather kernel
1378 dim3 threads = dim3(16, 16, 1);
1379 dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1);
1380 if (blocks.y > 65535) {
1381 blocks.y = 65535;
1382 }
1383 jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
1384 index_t>
1385 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
1386 output_values
1387 .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
1388 x_values
1389 .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
1390 y_reshaped
1391 .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
1392 y_reshaped
1393 .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
1394 t_rows_after_bs
1395 .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1396 t_cols_after_bs
1397 .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
1398 nnz,
1399 E,
1400 [f] __device__(__half x, __half y0, __half) -> __half {
1401 // NB: added the static_casts here
1402 return static_cast<__half>(
1403 f(static_cast<scalar_t>(x), static_cast<scalar_t>(y0))
1404 );
1405 });
1406 C10_CUDA_KERNEL_LAUNCH_CHECK();
1407 }); // AT_DISPATCH
1408 } else {
1409 JAGGED_TENSOR_DISPATCH_DIMS();
1410 C10_CUDA_KERNEL_LAUNCH_CHECK();
1411 }
1412 }
1413
1414 at::Tensor _fbgemm_jagged_to_padded_dense_forward(
1415 const Tensor& values,
1416 TensorList offsets,
1417 c10::IntArrayRef max_lengths,
1418 const double padding_value) {
1419 const size_t num_jagged_dim = offsets.size();
1420 TORCH_CHECK(
1421 max_lengths.size() == num_jagged_dim,
1422 "max_lengths.size(), ",
1423 max_lengths.size(),
1424 " != num_jagged_dim, ",
1425 num_jagged_dim);
1426 at::cuda::OptionalCUDAGuard device_guard;
1427 device_guard.set_index(values.get_device());
1428
1429 const Tensor values_canonicalized = values.view(
1430 {values.size(0),
1431 std::accumulate(
1432 values.sizes().begin() + 1,
1433 values.sizes().end(),
1434 1,
1435 std::multiplies<size_t>())});
1436 at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
1437 padded_values_shape.insert(
1438 padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
1439
1440 // Canonicalize padded_values by unsqueezing the last dim if the inner dense
1441 // dimension is 1 and folded.
1442 const bool D_folded = values.dim() == 1;
1443 if (!D_folded) {
1444 padded_values_shape.push_back(values.size(-1));
1445 }
1446 Tensor padded_values =
1447 at::empty_symint(padded_values_shape, values.options());
1448 Tensor padded_values_view =
1449 D_folded ? padded_values.unsqueeze(-1) : padded_values;
1450
1451 AT_DISPATCH_ALL_TYPES_AND2(
1452 at::ScalarType::Half,
1453 at::ScalarType::BFloat16,
1454 values.scalar_type(),
1455 "jagged_to_padded_dense",
1456 [&] {
1457 jagged_dense_elementwise_dense_output_<scalar_t>(
1458 values_canonicalized,
1459 offsets.vec(),
1460 padded_values_view, // dummy not used in the lambda function
1461 padded_values_view,
1462 [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
1463 return x;
1464 },
1465 static_cast<scalar_t>(padding_value));
1466 });
1467
1468 return padded_values;
1469 }
1470
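// A small worked example of the op above (the shapes are illustrative): with
// values of shape [6, D], a single offsets tensor [0, 2, 3, 6] (so the three
// batch entries own 2, 1, and 3 rows) and max_lengths = {4}, the result is a
// [3, 4, D] dense tensor whose first 2 / 1 / 3 rows per entry are copied from
// values and whose remaining rows are filled with padding_value.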
1471 #define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \
1472 AT_DISPATCH_CASE(TYPE, [&] { \
1473 jagged_dense_elementwise_jagged_output_opt_<scalar_t>( \
1474 values, \
1475 offsets.vec(), \
1476 dense, \
1477 output, \
1478 [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \
1479 return y; \
1480 }); \
1481 })
1482
1483 Tensor _fbgemm_dense_to_jagged_forward_symint(
1484 const Tensor& dense,
1485 TensorList offsets,
1486 std::optional<at::SymInt> total_L) {
1487 // D is the embedding dimension
1488 auto D = dense.size(-1);
1489
1490 // If total_L is not given then compute it
1491 at::SymInt total_L_computed;
1492 if (total_L.has_value()) {
1493 total_L_computed = total_L.value();
1494 } else {
1495 total_L_computed = (int64_t)offsets.back().max().item<int64_t>();
1496 }
1497 auto values = at::empty_symint({total_L_computed, D}, dense.options());
1498 auto output = at::empty_like(values);
1499
1500 at::cuda::OptionalCUDAGuard device_guard;
1501 device_guard.set_index(dense.get_device());
1502
1503 // clang-format off
1504 AT_DISPATCH_SWITCH(
1505 values.scalar_type(),
1506 "dense_to_jagged_gpu_op_forward",
1507 DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half)
1508 // NB: removed this to build
1509 // DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int)
1510 AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
1511 at::ScalarType::Long,
1512 at::ScalarType::BFloat16,
1513 [&] {
1514 jagged_dense_elementwise_jagged_output_<scalar_t>(
1515 values,
1516 offsets.vec(),
1517 dense,
1518 output,
1519 [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
1520 return y;
1521 }); // device lambda
1522 } // lambda
1523 ) // CASE_FLOATING_TYPES_AND
1524 ); // SWITCH
1525 // clang-format on
1526
1527 #undef DISPATCH_DENSE_TO_JAGGED_CASE
1528
1529 return output;
1530 }
1531
1532 } // namespace native
1533 } // namespace at
1534