xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The algorithm for dynamic partition has the following steps:
17 // 1. Let N be the size of partitions. We initialize a new vector indices_in
18 //    with the values 0, 1, 2, ..., N-1.
19 // 2. We apply gpuprim::DeviceRadixSort::SortPairs to the key - value pairs
20 //    given by partitions and indices_in. This will result in two new vectors
21 //    partitions_out and indices_out, with partitions_out sorted.
22 // 3. The first dimension of outputs[i] is equal to the number of i-values in
23 //    partitions_out. We determine it in two steps:
24 //    - apply gpuprim::DeviceReduce::ReduceByKey to count how many times each
25 //      value appears in partitions_out,
26 //    - move the results to partition_count. This handles missing values
27 //      (corresponding to empty parts).
28 // 4. Because partition_count is on the GPU, we bring it asynchronously to
29 //    the CPU. Then we can allocate the output tensors.
30 // 5. Finally, we use indices_out and the gather functor to collect the output.
31 //    This works, because for each interval of i-values, indices_out points
32 //    to the slices which should form output[i].
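//
// For example, with partitions = [0, 0, 2, 1, 0] and num_partitions = 3,
// step 2 yields partitions_out = [0, 0, 0, 1, 2] and
// indices_out = [0, 1, 4, 3, 2]; step 3 produces partition_count = [3, 1, 1];
// step 5 then gathers slices 0, 1 and 4 of data into outputs[0], slice 3 into
// outputs[1], and slice 2 into outputs[2].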
33 
34 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
35 
36 #define EIGEN_USE_GPU
37 
38 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
39 #include "tensorflow/core/framework/bounds_check.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/register_types.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/framework/tensor_reference.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/kernels/fill_functor.h"
46 #include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
47 #include "tensorflow/core/kernels/gpu_prim.h"
48 #include "tensorflow/core/util/gpu_kernel_helper.h"
49 #include "tensorflow/core/util/transform_output_iterator.h"
50 
51 #if GOOGLE_CUDA
52 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
53 using stream_executor::cuda::ScopedActivateExecutorContext;
54 #elif TENSORFLOW_USE_ROCM
55 #include "tensorflow/core/platform/rocm.h"
56 using stream_executor::rocm::ScopedActivateExecutorContext;
57 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
58 
59 namespace tensorflow {
60 
61 typedef Eigen::GpuDevice GPUDevice;
62 
63 namespace {
64 
65 template <typename T>
66 __global__ void RangeInitKernel(const T start, const T delta, const int32 size,
67                                 T* out) {
68   GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
69 }
70 
71 __global__ void MoveValuesKernel(const int32* keys, const int32* values,
72                                  const int32* size, int32 out_size,
73                                  int32* out) {
74   int32 N = min(ldg(size), out_size);
75   GPU_1D_KERNEL_LOOP(i, N) {
76     int32 key = ldg(keys + i);
77     int32 value = ldg(values + i);
78     if (FastBoundsCheck(key, out_size)) out[key] = value;
79   }
80 }
81 
82 // Initialize out with range start, start + delta, start + 2 * delta, ...
83 // This is needed because tf.range has no GPU implementation.
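// For instance, RangeInit(d, 0, 1, 5, out) fills out with [0, 1, 2, 3, 4].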
84 template <typename T>
85 void RangeInit(const GPUDevice& d, const T start, const T delta,
86                const int32 size, typename TTypes<T>::Flat out) {
87   GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
88   TF_CHECK_OK(GpuLaunchKernel(RangeInitKernel<T>, config.block_count,
89                               config.thread_per_block, 0, d.stream(), start,
90                               delta, size, out.data()));
91 }
92 
93 // Given *num_runs (key, value) pairs, this function moves the value
94 // corresponding to key i to position i in the array out.
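// For example, with keys = [0, 2], values = [3, 1], *num_runs = 2 and
// out_size = 3, out becomes [3, 0, 1], assuming out was zero-initialized
// beforehand (as partition_count is in CountAndSortParts below).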
95 void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
96                 int32 out_size, int32* out) {
97   // Because num_runs is located on the GPU, we cannot access it directly.
98   // So we launch the kernel with size = out_size.
99   // This is valid for correct inputs, because then out_size >= *num_runs.
100   // For wrong inputs, we may have out_size < *num_runs. In this case we will
101   // only handle the first out_size values.
102   GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
103   TF_CHECK_OK(GpuLaunchKernel(MoveValuesKernel, config.block_count,
104                               config.thread_per_block, 0, d.stream(), keys,
105                               values, num_runs, out_size, out));
106 }
107 
108 struct IdentityOp {
109   __device__ int32 __forceinline__ operator()(const int32& a) const {
110     return a;
111   }
112 };
113 
114 // Define an output iterator that only allows assignment to
115 // positions within [base, base + limit).
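// For example, for BoundedOutputIterator it(ptr, IdentityOp(), 4) the store
// it[2] = 7 writes to ptr[2], while it[5] = 9 is silently dropped because
// the offset lies outside [0, 4).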
116 class BoundedOutputIterator
117     : public TransformOutputIterator<int32, int32, IdentityOp> {
118  private:
119   int32 limit;
120   int32* base;
121 
122   struct BoundedReference : Reference {
123     int32 limit;
124     int32* base;
125     // Constructor
126     __host__ __device__ __forceinline__
127     BoundedReference(int32* __restrict__ ptr, int32* __restrict__ base,
128                      IdentityOp op, int32 limit)
129         : Reference(ptr, op), limit(limit), base(base) {}
130 
131     // Assignment
132     __host__ __device__ __forceinline__ int32 operator=(int32 val) {
133       if (ptr - base < limit && ptr - base >= 0) *ptr = val;
134       return val;
135     }
136   };
137 
138  public:
139   typedef BoundedOutputIterator self_type;
140   typedef BoundedReference reference;
141 
142   __host__ __device__ __forceinline__
143   BoundedOutputIterator(int32* __restrict__ ptr, IdentityOp op, int32 size)
144       : TransformOutputIterator(ptr, op), limit(size), base(ptr) {}
145 
146   __host__ __device__ __forceinline__
147   BoundedOutputIterator(int32* __restrict__ ptr, int32* __restrict__ base,
148                         IdentityOp op, int32 size)
149       : TransformOutputIterator(ptr, op), limit(size), base(base) {}
150 
151   // Indirection
152   __host__ __device__ __forceinline__ reference operator*() const {
153     return BoundedReference(ptr, base, conversion_op, limit);
154   }
155 
156   // Array subscript
157   __host__ __device__ __forceinline__ reference operator[](int32 n) const {
158     return BoundedReference(ptr + n, base, conversion_op, limit);
159   }
160 
161   // Addition
162   __host__ __device__ __forceinline__ self_type operator+(int32 n) const {
163     self_type retval(ptr + n, base, conversion_op, limit);
164     return retval;
165   }
166 
167   // Subtraction
168   __host__ __device__ __forceinline__ self_type operator-(int32 n) const {
169     self_type retval(ptr - n, base, conversion_op, limit);
170     return retval;
171   }
172 };
173 
174 }  // namespace
175 
176 // The current implementation has memory cost on GPU
177 // I + P + max(3N + R + P, O + N), where:
178 // I - the size of the input
179 // N - the size of the partitions tensor
180 // R - the temporary storage used by gpuprim::RadixSort, about 2N
181 // P - the number of partitions
182 // O - the size of the output
183 // So roughly the cost is I + P + max(5N, O + N).
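// Since, for valid inputs, every slice of the input ends up in exactly one
// output, O equals I, so for I >= 4N the bound is roughly 2I + N + P.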
184 template <typename T>
185 class DynamicPartitionOpGPU : public AsyncOpKernel {
186  public:
187   explicit DynamicPartitionOpGPU(OpKernelConstruction* c) : AsyncOpKernel(c) {
188     OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
189     OP_REQUIRES(c, num_partitions_ >= 1,
190                 errors::InvalidArgument("num_partitions must be at least 1"));
191   }
192 
193   void AllocateTempSpace(OpKernelContext* c, int32 N, Tensor* indices_in,
194                          Tensor* partitions_out, Tensor* indices_out,
195                          DoneCallback done) {
196     int32 M = std::max(N, num_partitions_);
197     // indices_in will be made slightly larger to accommodate
198     // later computations.
199     OP_REQUIRES_OK_ASYNC(
200         c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done);
201     OP_REQUIRES_OK_ASYNC(
202         c, c->allocate_temp(DT_INT32, TensorShape({N}), partitions_out), done);
203     OP_REQUIRES_OK_ASYNC(
204         c, c->allocate_temp(DT_INT32, TensorShape({N}), indices_out), done);
205   }
206 
207   void AllocateOutputs(OpKernelContext* c, const Tensor* data,
208                        const Tensor* partitions, const Tensor* partition_count,
209                        OpOutputList* Tout, DoneCallback done) {
210     auto e_part_count = partition_count->flat<int32>();
211     // Allocate output tensors of the right size
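    // For example, with data.shape = [6, 3], partitions.shape = [6] and
    // partition_count = [2, 4], the outputs get shapes [2, 3] and [4, 3].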
212     OP_REQUIRES_OK_ASYNC(c, c->output_list("outputs", Tout), done);
213     for (int p = 0; p < num_partitions_; p++) {
214       TensorShape shape;
215       shape.AddDim(e_part_count(p));
216       for (int i = partitions->dims(); i < data->dims(); i++) {
217         shape.AddDim(data->dim_size(i));
218       }
219       Tensor* out;
220       OP_REQUIRES_OK_ASYNC(c, Tout->allocate(p, shape, &out), done);
221     }
222   }
223 
224   void ComputeAsync(OpKernelContext* c, DoneCallback done) {
225     const Tensor& data = c->input(0);
226     const Tensor& partitions = c->input(1);
227 
228     OP_REQUIRES_ASYNC(
229         c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()),
230         errors::InvalidArgument(
231             "data.shape must start with partitions.shape, ",
232             "got data.shape = ", data.shape().DebugString(),
233             ", partitions.shape = ", partitions.shape().DebugString()),
234         done);
235 
236     Tensor partition_count;
237 
238     // We must handle the case of empty partitions separately,
239     // because kernels don't work with 0-sized tensors.
240     if (partitions.NumElements() == 0) {
241       AllocatorAttributes alloc_attr;
242       alloc_attr.set_on_host(true);
243       OP_REQUIRES_OK_ASYNC(
244           c,
245           c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
246                            &partition_count, alloc_attr),
247           done);
248       auto e_part_count = partition_count.flat<int32>();
249       for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0;
250       OpOutputList outputs;
251       this->AllocateOutputs(c, &data, &partitions, &partition_count, &outputs,
252                             done);
253       if (c->status().ok()) done();
254       return;
255     }
256 
257     // Prepare for counting.
258     OP_REQUIRES_OK_ASYNC(
259         c,
260         c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
261                          &partition_count),
262         done);
263     Tensor indices_out;
264     // Count how many times each partition index occurs.
265     // Also sort the info in partitions and output it in indices_out,
266     // in preparation for the next step.
267     this->CountAndSortParts(c, &partitions, &partition_count, &indices_out,
268                             done);
269     if (!c->status().ok()) return;
270 
271     // In order to allocate the output tensors we have to move partition_count
272     // to the CPU.
273     auto* stream = c->op_device_context()->stream();
274     OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
275                       done);
276     Tensor cpu_tensor;
277     AllocatorAttributes alloc_attr;
278     alloc_attr.set_on_host(true);
279     alloc_attr.set_gpu_compatible(true);
280     OP_REQUIRES_OK_ASYNC(
281         c,
282         c->allocate_temp(partition_count.dtype(), partition_count.shape(),
283                          &cpu_tensor, alloc_attr),
284         done);
285     se::DeviceMemoryBase wrapped(partition_count.flat<int32>().data(),
286                                  num_partitions_ * sizeof(int32));
287     const bool status =
288         stream
289             ->ThenMemcpy(cpu_tensor.flat<int32>().data(), wrapped,
290                          num_partitions_ * sizeof(int32))
291             .ok();
292     OP_REQUIRES_ASYNC(
293         c, status,
294         errors::Internal("Failed to launch copy from device to host."), done);
295 
296     // Keep a reference to partition_count so that the buffer
297     // is not deallocated at the end of the function, before
298     // memcpy is completed.
299     TensorReference partition_ref(partition_count);
300     auto wrapped_callback = [this, c, &data, &partitions, indices_out,
301                              partition_ref, cpu_tensor, done]() {
302       auto stream = c->op_device_context()->stream();
303       ScopedActivateExecutorContext scoped_activation{stream->parent()};
304 
305       OpOutputList outputs;
306       this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
307       if (!c->status().ok()) {
308         partition_ref.Unref();
309         return;
310       }
311       int32 N = partitions.NumElements();
312       int64 slice_size = data.NumElements() / N;
313       this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs);
314       partition_ref.Unref();
315       done();
316     };
317 
318     c->device()->tensorflow_accelerator_device_info()->event_mgr->ThenExecute(
319         stream, wrapped_callback);
320   }
321 
322  protected:
323   void RadixSort(OpKernelContext* c, const Tensor* partitions,
324                  Tensor* indices_in, Tensor* partitions_out,
325                  Tensor* indices_out, DoneCallback done) {
326     int32 N = partitions->NumElements();
327     const GPUDevice& device = c->eigen_device<GPUDevice>();
328     const auto& cu_stream = GetGpuStream(c);
329 
330     // Initialize the indices_in tensor using the Range GPU kernel.
331     RangeInit(device, 0, 1, N, indices_in->flat<int32>());
332     // Obtain the pointers to inner buffers.
333     const int32* partitions_ptr = partitions->flat<int32>().data();
334     int32* partitions_out_ptr = partitions_out->flat<int32>().data();
335     int32* indices_in_ptr = indices_in->flat<int32>().data();
336     int32* indices_out_ptr = indices_out->flat<int32>().data();
337     // Determine temporary device storage requirements.
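    // (gpuprim/CUB convention: a call with a NULL temporary-storage pointer
    // only computes temp_storage_bytes; the actual sort happens in the second
    // call below, once the scratch buffer has been allocated.)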
338     Tensor cub_temp_storage;
339     size_t temp_storage_bytes = 0;
340     gpuprim::DeviceRadixSort::SortPairs(
341         NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
342         indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream);
343     // Allocate temporary storage.
344     OP_REQUIRES_OK_ASYNC(
345         c,
346         c->allocate_temp(
347             DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
348             &cub_temp_storage),
349         done);
350     // Radix-sort the partition information.
351     gpuprim::DeviceRadixSort::SortPairs(
352         cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
353         partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
354         0, sizeof(int32) * 8, cu_stream);
355   }  // At this point cub_temp_storage will be marked for deallocation.
356 
357   void CountAndSortParts(OpKernelContext* c, const Tensor* partitions,
358                          Tensor* partition_count, Tensor* indices_out,
359                          DoneCallback done) {
360     const GPUDevice& device = c->eigen_device<GPUDevice>();
361     const auto& cu_stream = GetGpuStream(c);
362     int32 N = partitions->NumElements();
363     Tensor indices_in;
364     Tensor partitions_out;
365     Tensor aggregates_out;
366 
367     // Allocate memory for Radix-Sort.
368     this->AllocateTempSpace(c, N, &indices_in, &partitions_out, indices_out,
369                             done);
370     if (!c->status().ok()) return;
371     this->RadixSort(c, partitions, &indices_in, &partitions_out, indices_out,
372                     done);
373     if (!c->status().ok()) return;
374     // We will now apply a reduce operation to count how many times
375     // each index appears in partitions.
376 
377     // Zero-out the partition_count tensor.
378     functor::SetZeroFunctor<GPUDevice, int32> zero_functor;
379     zero_functor(device, partition_count->flat<int32>());
380     // Allocate memory for aggregates_out.
381     OP_REQUIRES_OK_ASYNC(
382         c,
383         c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
384                          &aggregates_out),
385         done);
386     // Obtain the pointers to inner buffers.
387     int32* keys_in_ptr = partitions_out.flat<int32>().data();
388     // Here we reuse the indices_in tensor for the unique keys output.
389     int32* unique_out_ptr = indices_in.flat<int32>().data();
390     int32* aggregates_out_ptr = aggregates_out.flat<int32>().data();
391     // We wrap the pointers in bounded output iterators to guard against
392     // wrong inputs (more than num_partitions distinct indices).
393     IdentityOp id_op;
394     BoundedOutputIterator unique_out_it(unique_out_ptr, id_op, num_partitions_);
395     BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
396                                             num_partitions_);
397 
398 #if GOOGLE_CUDA
399     cub::ConstantInputIterator<int32> values_in(1);
400 #elif TENSORFLOW_USE_ROCM
401     using ConstantInputIterator =
402         ::rocprim::constant_iterator<int32, ptrdiff_t>;
403     ConstantInputIterator values_in(1);
404 #endif
405     gpuprim::Sum reduction_op;
406 
407     // Allocate space on GPU for the number of runs. This is required by CUB.
408     Tensor num_runs;
409     OP_REQUIRES_OK_ASYNC(
410         c, c->allocate_temp(DT_INT32, TensorShape({1}), &num_runs), done);
411     int32* num_runs_ptr = num_runs.flat<int32>().data();
412 
413     // Determine temporary device storage requirements
414     Tensor cub_temp_storage;
415     size_t temp_storage_bytes = 0;
416     gpuprim::DeviceReduce::ReduceByKey(
417         NULL, temp_storage_bytes, keys_in_ptr, unique_out_it, values_in,
418         aggregates_out_it, num_runs_ptr, reduction_op, N, cu_stream);
419     // Allocate temporary storage.
420     OP_REQUIRES_OK_ASYNC(
421         c,
422         c->allocate_temp(
423             DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
424             &cub_temp_storage),
425         done);
426     // Run reduce-by-key. The effect is that we count how many times
427     // each index appears in partitions. The distinct indices are stored
428     // in unique_out, while the count is stored in aggregates_out.
429     // The total number of distinct indices is stored in num_runs.
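    // Continuing the example from the header comment: keys [0, 0, 0, 1, 2]
    // with constant value 1 yield unique_out = [0, 1, 2],
    // aggregates_out = [3, 1, 1] and *num_runs = 3.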
430     gpuprim::DeviceReduce::ReduceByKey(
431         cub_temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in_ptr,
432         unique_out_it, values_in, aggregates_out_it, num_runs_ptr, reduction_op,
433         N, cu_stream);
434     // We are not done yet. unique_out only contains the indices that appeared
435     // at least once in partitions. We move each value from aggregates_out
436     // to the corresponding position in partition_count. This will handle
437     // possibly empty parts.
438     MoveValues(device, unique_out_ptr, aggregates_out_ptr, num_runs_ptr,
439                num_partitions_, partition_count->flat<int32>().data());
440   }  // At this point indices_in, partitions_out, aggregates_out
441      // and cub_temp_storage will be marked for deallocation.
442 
443   void GatherSlices(OpKernelContext* c, const Tensor* data,
444                     const Tensor* indices, int32 N, int64 slice_size,
445                     OpOutputList& outs) {
446     const GPUDevice& device = c->eigen_device<GPUDevice>();
447     const int32* ind_base = indices->flat<int32>().data();
448     const T* data_base = data->flat<T>().data();
449 
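    // Walk over the sorted indices partition by partition: outs[p] consumes
    // the next indices_size entries of indices, so ind_base is advanced after
    // each gather.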
450     for (int p = 0; p < num_partitions_; p++) {
451       int32 indices_size = outs[p]->dim_size(0);
452       int64 out_size = outs[p]->NumElements();
453       T* out_base = outs[p]->flat<T>().data();
454       if (out_size > 0)
455         TF_CHECK_OK(LaunchGatherKernel</*is_axis_zero = */ true>(
456             device, data_base, ind_base, out_base, N, indices_size, slice_size,
457             out_size));
458       ind_base += indices_size;
459     }
460   }
461 
462   int32 num_partitions_;
463 };
464 
465 #define REGISTER_DYNAMIC_PARTITION_GPU(T)                                 \
466   REGISTER_KERNEL_BUILDER(                                                \
467       Name("DynamicPartition").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
468       DynamicPartitionOpGPU<T>)
469 
470 TF_CALL_int32(REGISTER_DYNAMIC_PARTITION_GPU);
471 TF_CALL_int64(REGISTER_DYNAMIC_PARTITION_GPU);
472 TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
473 TF_CALL_COMPLEX_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
474 #undef REGISTER_DYNAMIC_PARTITION_GPU
475 
476 }  // namespace tensorflow
477 
478 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
479