/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <cstdint>  // for std::uintptr_t

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

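// GatherOpKernel: `params` is viewed as a flattened
// [batch, gather_dim_size, slice_size] tensor and `out` as a flattened
// [batch, indices_size, slice_size] tensor with `out_size` elements in total.
// Each iteration of GPU_1D_KERNEL_LOOP handles one output element (one
// ValueOrVec, which may itself pack several scalar values).  When
// is_axis_zero is true the batch dimension is 1, so the batch index math is
// skipped.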
template <typename ValueOrVec, typename Index, bool is_axis_zero>
__global__ void GatherOpKernel(const ValueOrVec* __restrict__ params,
                               const Index* __restrict__ indices,
                               ValueOrVec* __restrict__ out,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  GPU_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;
    Index indices_i = 0;
    Index slice_i = 0;
    if (is_axis_zero) {
      indices_i = i / slice_size;
      slice_i = i - indices_i * slice_size;
    } else {
      Index batch_indices_i = i / slice_size;
      // The batch index into params to use for i.
      batch_i = batch_indices_i / indices_size;
      // The index into indices to use for i.
      indices_i = batch_indices_i - batch_i * indices_size;
      // Index into the current slice in params to use for i.
      slice_i = i - batch_indices_i * slice_size;
    }

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + indices_i);

    // Check gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set the output for out-of-range indices to zero.
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = ValueOrVec(0);
    } else {
      // params is a [batch_size, gather_dim_size, slice_size] tensor. Read
      // params[batch_i, gather_i, slice_i] and write it to the i'th position
      // in out.
      Index params_i =
          (batch_i * gather_dim_size + gather_i) * slice_size + slice_i;
      out[i] = params[params_i];
    }
  }
}
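// Worked example of the index arithmetic above (illustrative values only):
// with gather_dim_size = 5, indices_size = 3, slice_size = 4 and
// is_axis_zero == false, output element i = 30 decomposes as
//   batch_indices_i = 30 / 4 = 7
//   batch_i         = 7 / 3  = 2
//   indices_i       = 7 - 2 * 3  = 1
//   slice_i         = 30 - 7 * 4 = 2
// so the kernel reads params[(2 * 5 + indices[1]) * 4 + 2], i.e. element
// [2, indices[1], 2] of the [batch, gather_dim, slice] view of params.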

namespace detail {

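// Impl<vec_size> launches GatherOpKernel with ValueOrVec set to
// AlignedVector<T, vec_size>, so each thread moves vec_size elements of T per
// load/store.  The DCHECKs below assume that slice_size, out_size and both
// data pointers are divisible by (respectively aligned to) vec_size; e.g.
// with slice_size = 8 and vec_size = 4 the kernel sees slice_size_vec = 2.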
template <bool is_axis_zero>
struct LaunchGatherKernelVectorized {
  template <int vec_size>
  struct Impl {
    template <typename T, typename Index>
    Status operator()(const GPUDevice& d, const T* params, const Index* indices,
                      T* out, int64 gather_dim_size, int64 indices_size,
                      int64 slice_size, int64 out_size) {
      DCHECK_EQ(slice_size % vec_size, 0);
      DCHECK_EQ(out_size % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(params) % vec_size, 0);
      DCHECK_EQ(reinterpret_cast<std::uintptr_t>(out) % vec_size, 0);
      int64 out_size_vec = out_size / vec_size;
      int64 slice_size_vec = slice_size / vec_size;
      using Tvec = AlignedVector<T, vec_size>;
      const Tvec* params_vec = reinterpret_cast<const Tvec*>(params);
      Tvec* out_vec = reinterpret_cast<Tvec*>(out);

      GpuLaunchConfig config = GetGpuLaunchConfig(
          out_size_vec, d, &GatherOpKernel<Tvec, Index, is_axis_zero>,
          /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
      return GpuLaunchKernel(
          GatherOpKernel<Tvec, Index, is_axis_zero>, config.block_count,
          config.thread_per_block, 0, d.stream(), params_vec, indices, out_vec,
          gather_dim_size, indices_size, slice_size_vec, out_size_vec);
    }
  };
};

}  // namespace detail

template <bool is_axis_zero, typename T, typename Index>
Status LaunchGatherKernel(const GPUDevice& d, const T* params,
                          const Index* indices, T* out, int64 gather_dim_size,
                          int64 indices_size, int64 slice_size,
                          int64 out_size) {
  // Note that the GPU memory allocator always returns aligned buffers, so the
  // alignment of data pointers is expected to be deterministic.
  // There will be performance cliffs when slice_size is not aligned, but there
  // is no easy way to handle the misalignment because each row will be aligned
  // differently.
  return DispatchToVectorized<
      T, detail::LaunchGatherKernelVectorized<is_axis_zero>::template Impl>(
      MinAlignmentOf(params, out, slice_size), d, params, indices, out,
      gather_dim_size, indices_size, slice_size, out_size);
}
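// In effect, MinAlignmentOf reports the widest vector width (in elements)
// compatible with the alignment of `params`, `out` and `slice_size`, and
// DispatchToVectorized instantiates the matching Impl<vec_size> above,
// falling back to vec_size == 1 when no vectorization is possible.  This is
// a summary of how the helpers are used here, not their full contract.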

namespace functor {
template <typename T, typename Index>
struct GatherFunctor<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need a check here since the CPU version does useful error checking
      // work if there are nonempty indices but empty slices, so the kernel is
      // executed in that case.  In the GPU case we don't know how to do error
      // checking, so we skip the loop entirely.
      return -1;
    }
    const bool is_axis_zero = params.dimension(0) == 1;
    const int64 gather_dim_size = params.dimension(1);
    const int64 indices_size = indices.size();
    const int64 slice_size = params.dimension(2);

    if (is_axis_zero) {
      TF_CHECK_OK(LaunchGatherKernel<true>(d, params.data(), indices.data(),
                                           out.data(), gather_dim_size,
                                           indices_size, slice_size, out_size));
    } else {
      TF_CHECK_OK(LaunchGatherKernel<false>(
          d, params.data(), indices.data(), out.data(), gather_dim_size,
          indices_size, slice_size, out_size));
    }
    // TODO(fpmc): enable indices validation on GPU.
    // Right now, checking for out-of-bounds indices in the kernel would
    // require copying code between the GPU and CPU paths, and would thus be
    // slow.
    return -1;
  }
};
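
// Usage sketch (illustrative only; the exact caller-side reshaping lives in
// the gather op, not in this header): a caller would typically present
// `params` as a rank-3 [outer, gather_dim, slice] view and `indices` as a
// flat view, e.g.
//
//   functor::GatherFunctor<GPUDevice, T, Index> gather;
//   gather(ctx, params.shaped<T, 3>({outer, gather_dim, slice}),
//          indices.flat<Index>(),
//          out->shaped<T, 3>({outer, indices.NumElements(), slice}));
//
// The GPU specialization always returns -1 because indices are not validated
// on the host (see the TODO above).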

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_