/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/platform/bfloat16.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // The caller must check ctx->status() upon return for a non-OK status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64_t out_num_elements;

    int64_t in0_num_elements;
    int64_t in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
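//
// Usage sketch (illustrative; the actual registrations live in the
// cwise_op_*.cc files): a BinaryOp kernel is typically registered via the
// REGISTER* macros defined at the bottom of this header, e.g. something like
//   REGISTER2(BinaryOp, CPU, "Add", functor::add, float, double);
// which expands to REGISTER_KERNEL_BUILDER calls for
// BinaryOp<CPUDevice, functor::add<float>> and
// BinaryOp<CPUDevice, functor::add<double>>.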
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    OP_REQUIRES(ctx, input_0.dtype() == DataTypeToEnum<Tin>::v(),
                errors::InvalidArgument(
                    "Expected tensor of type ",
                    DataTypeString(DataTypeToEnum<Tin>::v()), " but got type ",
                    DataTypeString(input_0.dtype())));
    const Tensor& input_1 = ctx->input(1);
    OP_REQUIRES(ctx, input_1.dtype() == DataTypeToEnum<Tin>::v(),
                errors::InvalidArgument(
                    "Expected tensor of type ",
                    DataTypeString(DataTypeToEnum<Tin>::v()), " but got type ",
                    DataTypeString(input_1.dtype())));
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state' is a shared helper that does not depend on T, to reduce code size.
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

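// Coefficient-wise comparison that returns true wherever |x - y| <= tolerance.
// Both inputs must have the same shape; no broadcasting is performed.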
template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
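//
// Usage sketch (illustrative; the actual registrations live in the
// cwise_op_*.cc files): a gradient kernel might be registered as, e.g.,
//   REGISTER2(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float, double);
// Because the inputs are known to have the same number of elements, Compute()
// skips BCast entirely and applies the functor element by element.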
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(
        ctx, in0.NumElements() == in1.NumElements(),
        errors::InvalidArgument("The two arguments to a cwise op must have "
                                "same number of elements, got ",
                                in0.NumElements(), " and ", in1.NumElements()));
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
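//
// Usage sketch (illustrative; the actual registrations live in the
// cwise_op_*.cc files): a unary kernel might be registered as, e.g.,
//   REGISTER2(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double);
// The output buffer is either forwarded from the input (when Tin == Tout) or
// freshly allocated, and the functor is applied over the flattened tensor.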
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

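// Applies the unary variant op identified by OpEnum (a VariantUnaryOp value)
// to a scalar DT_VARIANT input, producing a scalar DT_VARIANT output via
// UnaryOpVariant<Device>.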
template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

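// Evaluates the Eigen expression `rhs` on device `d` and writes the result
// into `out`.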
template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and by avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many 1s appear in
      // in0's and in1's shapes (4 numbers in total). The two shapes cannot
      // have more than two 1s in total, because such cases are simplified
      // to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndims combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc does a
      // decent job of not generating code when the conditions are not met.
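      //
      // Worked example (illustrative): with in0 of shape [1, 5] and in1 of
      // shape [3, 5], a == 1 below, so in0 is reshaped to [1, 5] and
      // broadcast to [3, 5] while in1 is used as-is; the functor is then
      // applied coefficient-wise to produce a [3, 5] output.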
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

template <typename Functor, typename Targ>
struct UnaryFunctorWithArg<CPUDevice, Functor, Targ> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in, Targ val) {
    Assign(d, out, in.unaryExpr(typename Functor::func(val)));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

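// Registers kernel class OP<D##Device, F<T>> under op name N on device D,
// constrained to type T.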
#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
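//
// For example (illustrative), a hypothetical registration such as
//   REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double, Eigen::half)
// expands to three REGISTER invocations, one per listed type (or, under
// __ANDROID_TYPES_SLIM__, to a single registration of the first listed type,
// per the block below).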

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_