// tensorflow/core/kernels/cwise_op_select.cc
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;


namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

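// Implements the Select op: output[i] = cond[i] ? then[i] : else_[i].
// Compute() dispatches on the shape of `cond`:
//   * scalar `cond`: forward either `then` or `else_` wholesale;
//   * vector `cond` with higher-rank `then`/`else_`: select whole rows
//     (batches) of `then`/`else_`;
//   * otherwise: elementwise selection over identically shaped inputs.
//
// Illustrative example of the elementwise case:
//   cond = [true, false, true], then = [1, 2, 3], else_ = [10, 20, 30]
//   output = [1, 20, 3]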
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }

  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    functor::SelectScalarHandler<Device, T> handler;
    handler(ctx, cond, then, else_);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
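
// Implements the SelectV2 op, which broadcasts `cond`, `then`, and `else_`
// against each other in numpy style and then selects elementwise:
//   output = where(broadcast(cond), broadcast(then), broadcast(else_))
// Broadcasting is supported up to rank 8; anything higher returns an
// Unimplemented error.
//
// Illustrative example:
//   cond  = [[true], [false]]   (shape [2, 1])
//   then  = [[1, 2], [3, 4]]    (shape [2, 2])
//   else_ = 100                 (scalar)
//   output = [[1, 2], [100, 100]]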
template <typename Device, typename T>
class SelectV2Op : public OpKernel {
 public:
  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    // The shapes of `cond`, `then`, and `else` must be broadcastable
    // (bcast.IsValid()); this matches the behavior of numpy.
    BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(),
                        else_->shape().dim_sizes()},
                       false);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "condition ", cond->shape().DebugString(), ", then ",
                    then->shape().DebugString(), ", and else ",
                    else_->shape().DebugString(), " must be broadcastable"));

    // Broadcast `cond`, `then`, and `else` to the combined shape in order to
    // obtain each input's reshape and broadcast parameters.
    BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false);
    BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false);
    BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false);
    OP_REQUIRES(
        ctx,
        cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable"));

    // Combined shape should be the final shape.
    OP_REQUIRES(
        ctx,
        cond_bcast.output_shape() == bcast.output_shape() &&
            then_bcast.output_shape() == bcast.output_shape() &&
            else_bcast.output_shape() == bcast.output_shape(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable to the same shape"));

    Tensor* output = nullptr;
    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", output_shape, &output));

    if (output->NumElements() == 0) {
      return;
    }

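// HANDLE_DIM(NDIMS) instantiates the rank-NDIMS broadcast-select functor:
// each input is viewed through its BCast reshape and broadcast to the common
// output shape before the elementwise selection runs on the device.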
#define HANDLE_DIM(NDIMS)                                            \
  {                                                                  \
    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
    func(ctx->eigen_device<Device>(),                                \
         output->shaped<T, NDIMS>(bcast.result_shape()),             \
         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
  }

    const int ndims = static_cast<int>(bcast.result_shape().size());
    switch (ndims) {
      case 1:
        HANDLE_DIM(1);
        break;
      case 2:
        HANDLE_DIM(2);
        break;
      case 3:
        HANDLE_DIM(3);
        break;
      case 4:
        HANDLE_DIM(4);
        break;
      case 5:
        HANDLE_DIM(5);
        break;
      case 6:
        HANDLE_DIM(6);
        break;
      case 7:
        HANDLE_DIM(7);
        break;
      case 8:
        HANDLE_DIM(8);
        break;
      default:
        ctx->SetStatus(errors::Unimplemented(
            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
            ctx->input(1).shape().DebugString(), " is not supported yet."));
        break;
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};

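// Registration of the CPU implementations.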
#define REGISTER_SELECT(type)                                        \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      SelectOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectV2Op<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      SelectOp<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectV2Op<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64_t);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#else

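// With MLIR-generated GPU kernels enabled, only the legacy Select op is
// registered here; SelectV2 on GPU is expected to be covered by the
// MLIR-generated kernels.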
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64_t);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


namespace functor {

// CPU Specializations of Select functors.
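// Elementwise select on flattened tensors:
//   out[i] = cond_flat[i] ? then_flat[i] : else_flat[i]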
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};

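// Generic handler for a scalar `cond`: allocates (or forwards) the output and
// dispatches to SelectScalarFunctor for the device.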
template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for the CPU device. Forwards either the `then` or the `else`
// input to the output, depending on the `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch the bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};


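// Batch (row-wise) select used by SelectOp's broadcasting path: each entry of
// `cond_vec` selects the entire corresponding row of the flattened-to-2D
// `then`/`else` inputs.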
template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A fast CPU implementation that uses an explicit loop to avoid broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

    auto work = [batch_size, output, c, t, e](int64_t start, int64_t end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
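        // Prefetch the next row of `then`/`else` and the next `cond` entry to
        // overlap memory latency with the copy of the current row.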
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

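// Rank-templated functor backing SelectV2: broadcasts all three inputs to the
// common output shape and applies Eigen's elementwise select.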
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
                                  .select(then_tensor.broadcast(then_bcast),
                                          else_tensor.broadcast(else_bcast));
  }
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};


}  // namespace functor

}  // namespace tensorflow