#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/reverse_iterator.h>

namespace at::native {

void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
  sorted_indices.copy_(linearIndex);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  using device_ptr = thrust::device_ptr<int64_t>;

  // Fill orig_indices with the sequence 0, 1, ..., num_indices - 1
  const auto count_iter = thrust::counting_iterator<int64_t>(0);
  auto orig_data = device_ptr(orig_indices.mutable_data_ptr<int64_t>());
  thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
  // Sort the indices, carrying the original positions along as values; a
  // stable or multidimensional sort is not needed here, so use Thrust
  // directly.
  // NB: not passing a comparator makes Thrust pick radix sort, which hurts
  // performance a lot, at least for medium-sized (a few thousand) index sets.
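  // Illustrative trace (hypothetical data): for linearIndex = [3, 1, 3, 0],
  // the sequence filled above is orig_indices = [0, 1, 2, 3]; after the
  // sort_by_key below, sorted_indices = [0, 1, 3, 3] and orig_indices is one
  // of the permutations mapping back to the original slots, e.g. [3, 1, 0, 2]
  // (ties between equal keys may land in either order).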
  auto sorted_data = device_ptr(sorted_indices.mutable_data_ptr<int64_t>());
  thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}
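
// Illustrative call (a sketch, assuming contiguous int64 CUDA tensors of the
// same length): after
//   index_put_with_sort_kernel_thrust_helper(lin, orig, sorted, lin.numel());
// `sorted` holds the ascending copy of `lin` and `orig` holds the permutation
// taking each sorted position back to its original position.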

#if !CUB_SUPPORTS_SCAN_BY_KEY()

template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto num_indices = count.numel();

  // Compute an increasing sequence per run of equal keys in sorted_indices:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 1 2 3 1 2 1 1 2
  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
  thrust::inclusive_scan_by_key(
    policy,
    sorted_data,
    sorted_data + num_indices,
    thrust::make_constant_iterator(1),
    count_data
  );

  // Take the maximum of each count per unique key in reverse:
  // sorted: 2 5 5 5 7 7 8 9 9
  //  count: 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
    policy,
    thrust::make_reverse_iterator(sorted_data + num_indices),
    thrust::make_reverse_iterator(sorted_data),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::equal_to<index_t>(),
    thrust::maximum<index_t>()
  );
}

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

#endif

template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
  const ptrdiff_t numel = sorted_indices.numel();
  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
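  // unique_by_key_copy must write the unique keys somewhere even though only
  // the value outputs (the segment offsets) are needed afterwards, so give it
  // a scratch tensor to write into.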
  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
  auto ends = thrust::unique_by_key_copy(
    policy,
    sorted_indices_dev,
    sorted_indices_dev + numel,
    thrust::make_counting_iterator(0),
    dummy_dev,
    thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
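  // Illustrative trace (reusing the example data above): for
  // sorted_indices = [2, 5, 5, 5, 7, 7, 8, 9, 9], the copied-out values are
  // the positions where each run of equal keys begins, so segment_offsets =
  // [0, 1, 4, 6, 7]. The distance from dummy_dev to the end of the copied
  // keys is the number of unique keys, which is returned below.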
  return thrust::get<0>(ends) - dummy_dev;
}

template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);

} // namespace at::native