// aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/reverse_iterator.h>

namespace at::native {

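// Sorts a flattened index tensor and records, for each sorted element, its
// original position. Illustrative values (hypothetical; the sort is not
// stable, so the relative order of duplicate keys is unspecified):
//   linearIndex    = {3, 1, 3, 0}
//   sorted_indices = {0, 1, 3, 3}
//   orig_indices   = {3, 1, 0, 2}  (or {3, 1, 2, 0})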
void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
  sorted_indices.copy_(linearIndex);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
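  // The custom allocator routes Thrust's temporary allocations through
  // PyTorch's CUDA caching allocator, and .on(stream) keeps all Thrust work
  // on the current CUDA stream.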

  using device_ptr = thrust::device_ptr<int64_t>;

  // Fill orig_indices with the sequence 0..num_indices-1
  const auto count_iter = thrust::counting_iterator<int64_t>(0);
  auto orig_data = device_ptr(orig_indices.mutable_data_ptr<int64_t>());
  thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);

  // Sort the keys together with their original positions; a stable or
  // multidimensional sort is not required, so use Thrust directly.
  // NB: omitting the comparator makes Thrust pick radix sort, which hurts
  // performance a lot, at least for medium-sized (a few thousand) index sets.
  auto sorted_data = device_ptr(sorted_indices.mutable_data_ptr<int64_t>());
  thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}

#if !CUB_SUPPORTS_SCAN_BY_KEY()

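// Fallback used when CUB's scan-by-key is unavailable. After the two scans
// below, count[i] holds the total number of occurrences of
// sorted_indices[i]; the embedding backward kernel uses these counts to
// scale gradients for repeated indices.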
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto num_indices = count.numel();

  // Compute an increasing sequence per unique item in sorted_indices:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 1 2 3 1 2 1 1 2
  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
  thrust::inclusive_scan_by_key(
    policy,
    sorted_data,
    sorted_data + num_indices,
    thrust::make_constant_iterator(1),
    count_data
  );
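  // At this point, count[i] is the 1-based rank of element i within its run
  // of equal keys.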

  // Take the maximum of each count per unique key in reverse:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
    policy,
    thrust::make_reverse_iterator(sorted_data + num_indices),
    thrust::make_reverse_iterator(sorted_data),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::equal_to<index_t>(),
    thrust::maximum<index_t>()
  );
}

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

#endif // !CUB_SUPPORTS_SCAN_BY_KEY()

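// Copies the start offset of each run of equal keys in sorted_indices into
// segment_offsets and returns the number of unique keys. Worked example:
//   sorted_indices  = 2 5 5 5 7 7 8 9 9
//   segment_offsets = 0 1 4 6 7
//   return value    = 5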
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
  const ptrdiff_t numel = sorted_indices.numel();
  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  // Scratch tensor that receives the unique keys; only the offsets are needed.
  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
  auto ends = thrust::unique_by_key_copy(
          policy,
          sorted_indices_dev,
          sorted_indices_dev + numel,
          thrust::make_counting_iterator(0),
          dummy_dev,
          thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
  // The number of unique keys is the distance between the start and end of
  // the copied-keys output range.
  return thrust::get<0>(ends) - dummy_dev;
}

template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);

} // namespace at::native