/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/platform/bfloat16.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // The caller must check ctx->status() upon return for a non-OK status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64_t out_num_elements;

    int64_t in0_num_elements;
    int64_t in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
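//
// As an illustrative sketch (the op name here is only an example; actual
// registrations live in the cwise_op_*.cc files), a CPU float kernel for an
// "Add" op would be registered roughly as:
//   REGISTER(BinaryOp, CPU, "Add", functor::add, float);
// which instantiates BinaryOp<CPUDevice, functor::add<float>> via the
// REGISTER macro defined near the end of this header.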
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    OP_REQUIRES(ctx, input_0.dtype() == DataTypeToEnum<Tin>::v(),
                errors::InvalidArgument(
                    "Expected tensor of type ",
                    DataTypeString(DataTypeToEnum<Tin>::v()), " but got type ",
                    DataTypeString(input_0.dtype())));
    const Tensor& input_1 = ctx->input(1);
    OP_REQUIRES(ctx, input_1.dtype() == DataTypeToEnum<Tin>::v(),
                errors::InvalidArgument(
                    "Expected tensor of type ",
                    DataTypeString(DataTypeToEnum<Tin>::v()), " but got type ",
                    DataTypeString(input_1.dtype())));
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': a shared helper that does not depend on T, to reduce code size.
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop if BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
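//
// As an illustrative sketch (the op name is only an example; actual
// registrations live in the cwise_op_*.cc files), a CPU float kernel for a
// "TanhGrad" op would be registered roughly as:
//   REGISTER(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float);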
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(
        ctx, in0.NumElements() == in1.NumElements(),
        errors::InvalidArgument("The two arguments to a cwise op must have "
                                "same number of elements, got ",
                                in0.NumElements(), " and ", in1.NumElements()));
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
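//
// As an illustrative sketch (the op name is only an example; actual
// registrations live in the cwise_op_*.cc files), a CPU float kernel for a
// "Sqrt" op would be registered roughly as:
//   REGISTER(UnaryOp, CPU, "Sqrt", functor::sqrt, float);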
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here we need to handle 6 cases depending on how many 1s appear among
      // the four dimensions of in0 and in1. The four dimensions cannot
      // contain more than two 1s in total, because such shapes are simplified
      // to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each Functor
      // (+, -, *, /, <, <=, etc.), type, and ndims combination, we only apply
      // it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc does a
      // decent job of not generating code when the conditions are not met.
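      //
      // Illustrative example (the shapes are hypothetical): if in0 has shape
      // [1, 5] and in1 has shape [3, 1], then a == 1 and d == 1, so the first
      // branch below reshapes and broadcasts both operands using IndexLists
      // whose unit dimensions are compile-time constants, producing an output
      // of shape [3, 5].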
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works, but is probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

template <typename Functor, typename Targ>
struct UnaryFunctorWithArg<CPUDevice, Functor, Targ> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in, Targ val) {
    Assign(d, out, in.unaryExpr(typename Functor::func(val)));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
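//
// For example (illustrative only; real uses live in the cwise_op_*.cc files),
//   REGISTER3(BinaryOp, CPU, "Add", functor::add, float, double, int32);
// expands to three REGISTER_KERNEL_BUILDER calls, one per type.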

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support, with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_