/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ",
", cond->NumElements())); 93 OP_REQUIRES( 94 ctx, then->shape().IsSameSize(else_->shape()), 95 errors::InvalidArgument( 96 "'then' and 'else' must have the same size. but received: ", 97 then->shape().DebugString(), " vs. ", 98 else_->shape().DebugString())); 99 100 Tensor* output = nullptr; 101 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 102 {"t", "e"}, "output", then->shape(), &output)); 103 if (output->NumElements() > 0) { 104 functor::BatchSelectFunctor<Device, T> func; 105 func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(), 106 cond->vec<bool>(), then->flat_outer_dims<T>(), 107 else_->flat_outer_dims<T>()); 108 } 109 } 110 ComputeElementwise(OpKernelContext * ctx,const Tensor * cond,const Tensor * then,const Tensor * else_)111 void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond, 112 const Tensor* then, const Tensor* else_) { 113 if (!ctx->ValidateInputsAreSameShape(this)) return; 114 Tensor* output = nullptr; 115 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 116 {"t", "e"}, "output", then->shape(), &output)); 117 if (output->NumElements() > 0) { 118 functor::SelectFunctor<Device, T> func; 119 func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(), 120 then->flat<T>(), else_->flat<T>()); 121 } 122 } 123 ComputeScalar(OpKernelContext * ctx,const Tensor * cond,const Tensor * then,const Tensor * else_)124 void ComputeScalar(OpKernelContext* ctx, const Tensor* cond, 125 const Tensor* then, const Tensor* else_) { 126 OP_REQUIRES( 127 ctx, then->shape().IsSameSize(else_->shape()), 128 errors::InvalidArgument( 129 "'then' and 'else' must have the same size. but received: ", 130 then->shape().DebugString(), " vs. ", 131 else_->shape().DebugString())); 132 133 functor::SelectScalarHandler<Device, T> handler; 134 handler(ctx, cond, then, else_); 135 } 136 137 private: 138 TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); 139 }; 140 template <typename Device, typename T> 141 class SelectV2Op : public OpKernel { 142 public: SelectV2Op(OpKernelConstruction * context)143 explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} 144 Compute(OpKernelContext * ctx)145 void Compute(OpKernelContext* ctx) override { 146 const Tensor* cond = &ctx->input(0); 147 const Tensor* then = &ctx->input(1); 148 const Tensor* else_ = &ctx->input(2); 149 150 // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), 151 // This matches the behavior of numpy. 152 BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(), 153 else_->shape().dim_sizes()}, 154 false); 155 OP_REQUIRES(ctx, bcast.IsValid(), 156 errors::InvalidArgument( 157 "condition ", cond->shape().DebugString(), ", then ", 158 then->shape().DebugString(), ", and else ", 159 else_->shape().DebugString(), " must be broadcastable")); 160 161 // Broadcast `cond`, `then` and `else` to combined shape, 162 // in order to obtain the reshape. 163 BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false); 164 BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false); 165 BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false); 166 OP_REQUIRES( 167 ctx, 168 cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(), 169 errors::InvalidArgument("condition ", cond->shape().DebugString(), 170 ", then ", then->shape().DebugString(), 171 ", and else ", else_->shape().DebugString(), 172 " must be broadcastable")); 173 174 // Combined shape should be the final shape. 

    // Combined shape should be the final shape.
    OP_REQUIRES(
        ctx,
        cond_bcast.output_shape() == bcast.output_shape() &&
            then_bcast.output_shape() == bcast.output_shape() &&
            else_bcast.output_shape() == bcast.output_shape(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable to the same shape"));

    Tensor* output = nullptr;
    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", output_shape, &output));

    if (output->NumElements() == 0) {
      return;
    }

#define HANDLE_DIM(NDIMS)                                            \
  {                                                                  \
    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
    func(ctx->eigen_device<Device>(),                                \
         output->shaped<T, NDIMS>(bcast.result_shape()),             \
         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
  }

    const int ndims = static_cast<int>(bcast.result_shape().size());
    switch (ndims) {
      case 1:
        HANDLE_DIM(1);
        break;
      case 2:
        HANDLE_DIM(2);
        break;
      case 3:
        HANDLE_DIM(3);
        break;
      case 4:
        HANDLE_DIM(4);
        break;
      case 5:
        HANDLE_DIM(5);
        break;
      case 6:
        HANDLE_DIM(6);
        break;
      case 7:
        HANDLE_DIM(7);
        break;
      case 8:
        HANDLE_DIM(8);
        break;
      default:
        ctx->SetStatus(errors::Unimplemented(
            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
            ctx->input(1).shape().DebugString(), " is not supported yet."));
        break;
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};

#define REGISTER_SELECT(type)                                        \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      SelectOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectV2Op<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
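
// When the MLIR-generated GPU kernels are not enabled, both Select and
// SelectV2 receive handwritten GPU registrations here; otherwise only the
// legacy Select op is registered in the branch below.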

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      SelectOp<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectV2Op<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#else

#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64_t);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

// CPU specializations of the Select functors.
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};

template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for CPU device. Forwards the chosen input to the output
// depending on the `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
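// Note: because the entire `then` or `else_` tensor is forwarded with
// set_output(), this path shares the input buffer with the output and
// performs no element-wise copy.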
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};

template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A faster CPU implementation that loops over batches explicitly instead of
// relying on Eigen broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

    auto work = [batch_size, output, c, t, e](int64_t start, int64_t end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
        // Prefetch the next batch of `then`, `else_`, and `cond` data.
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        // Copy the selected batch row into the output.
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
                                  .select(then_tensor.broadcast(then_bcast),
                                          else_tensor.broadcast(else_bcast));
  }
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};

}  // namespace functor

}  // namespace tensorflow