xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/binary_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Native XLA implementations of simple binary Ops
17 
18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
19 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
20 #include "tensorflow/compiler/tf2xla/shape_util.h"
21 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/client/client_library.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/math.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 
33 namespace tensorflow {
34 namespace {
35 
36 // A subclass of a XlaBinaryOp must build the computation that
37 // describes the (tensor,tensor)->tensor function to apply to each element of
38 // the input.
// A subclass of a XlaBinaryOp must build the computation that
// describes the (tensor,tensor)->tensor function to apply to each element of
// the input.
//
// XLA_MAKE_BINARY(NAME, HLO) expands to:
//   * a class NAME##Op deriving from XlaBinaryOp whose Computation() returns
//     the XLA expression HLO, evaluated with the operands `lhs`/`rhs`, the
//     builder `b`, the operand shapes, `broadcast_helper`, and
//     `extend_dimensions` all in scope; and
//   * a REGISTER_XLA_OP call binding the TensorFlow op named NAME to it.
// The (void) casts silence unused-variable warnings for expansions whose HLO
// expression does not reference every bound name.
#define XLA_MAKE_BINARY(NAME, HLO)                                         \
  class NAME##Op : public XlaBinaryOp {                                    \
   public:                                                                 \
    explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}     \
    xla::XlaOp Computation(                                                \
        XlaOpKernelContext* ctx, const xla::XlaOp& lhs,                    \
        const absl::Span<const int64_t>& lhs_shape, const xla::XlaOp& rhs, \
        const absl::Span<const int64_t>& rhs_shape,                        \
        const BCast& broadcast_helper,                                     \
        const std::vector<int64_t>& extend_dimensions) override {          \
      xla::XlaBuilder* b = ctx->builder();                                 \
      (void)b;                                                             \
      (void)lhs_shape;                                                     \
      (void)rhs_shape;                                                     \
      (void)extend_dimensions;                                             \
      return HLO;                                                          \
    }                                                                      \
  };                                                                       \
  REGISTER_XLA_OP(Name(#NAME), NAME##Op)
58 
// Arithmetic ops that map one-to-one onto a single XLA instruction.
XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
67 
68 // Implementation of DivNoNan. Pseudo-code:
69 // if (y == 0) {
70 //   return 0
71 // } else {
72 //   return x / y;
73 // }
DivNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)74 static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
75                                xla::XlaOp y, const BCast& broadcast_helper) {
76   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
77   auto zero = XlaHelpers::Zero(b, dtype);
78   auto y_equals_0 = xla::Eq(y, zero);
79   auto zeros = xla::ZerosLike(x);
80   auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
81   return result;
82 }
83 XLA_MAKE_BINARY(DivNoNan,
84                 DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
85 
86 // Implementation of MulNoNan. Pseudo-code:
87 // if (y == 0) {
88 //   return 0
89 // } else {
90 //   return x * y;
91 // }
MulNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)92 static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
93                                xla::XlaOp y, const BCast& broadcast_helper) {
94   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
95   auto zero = XlaHelpers::Zero(b, dtype);
96   auto y_equals_0 = xla::Eq(y, zero);
97   auto zeros = xla::ZerosLike(x);
98   auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
99   return result;
100 }
101 XLA_MAKE_BINARY(MulNoNan,
102                 MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
103 
// Implementation of FloorDiv.
//
// For floating-point values, simply returns floor(x / y).  For integers, does:
//
// z = x / y
// if (z * y != x && (x < 0) != (y < 0)) {
//   return  z - 1;
// } else {
//   return z;
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  if (DataTypeIsFloating(dtype)) {
    if (dtype == DataType::DT_BFLOAT16) {
      // The result of a BF16 division may produce the Ceil of what was
      // computed by F32 division, so avoid end user confusion by doing the
      // intermediate divide in F32.
      return xla::ConvertElementType(
          xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32),
                              xla::ConvertElementType(y, xla::F32))),
          xla::BF16);
    } else {
      return xla::Floor(xla::Div(x, y));
    }
  }
  if (DataTypeIsUnsigned(dtype)) {
    // For unsigned integers truncating division already equals floor
    // division, so no sign correction is needed.
    return xla::Div(x, y);
  }
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::One(b, dtype);
  auto x_div_y = xla::Div(x, y);
  // Signed integers: subtract 1 from the truncated quotient when it is
  // inexact (x_div_y * y != x) AND the operands have opposite signs —
  // Ne(Lt(x, 0), Lt(y, 0)) acts as a boolean XOR of the two sign tests.
  auto round_down = xla::And(xla::Ne(xla::Mul(x_div_y, y), x),
                             xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)));
  return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y);
}
XLA_MAKE_BINARY(FloorDiv,
                FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
142 
XlogyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)143 xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
144                      const BCast& broadcast_helper) {
145   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
146   auto zero = xla::ZerosLike(x);
147   auto is_zero = xla::Eq(x, zero);
148   return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
149 }
150 XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
151 
Xlog1pyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)152 xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
153                        const BCast& broadcast_helper) {
154   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
155   auto non_zero = xla::Mul(x, xla::Log1p(y));
156   auto zero = xla::ZerosLike(non_zero);
157   auto x_is_zero = xla::Eq(x, zero);
158   return xla::Select(x_is_zero, zero, non_zero);
159 }
160 XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper));
161 
XdivyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)162 xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y,
163                      const BCast& broadcast_helper) {
164   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
165   auto zero = xla::ZerosLike(x);
166   auto is_zero = xla::Eq(x, zero);
167   return xla::Select(is_zero, zero, xla::Div(x, y));
168 }
169 XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));
170 
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
//                                                   : trunc_mod;
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = XlaHelpers::Zero(b, dtype);
  auto trunc_mod = xla::Rem(x, y);
  auto trunc_mod_not_zero = xla::Ne(trunc_mod, zero);
  // Add y when the truncated remainder is nonzero and its sign differs from
  // y's (Ne of the two Lt tests is a boolean XOR), shifting the result into
  // the same sign as the divisor per floor-mod semantics.
  auto do_plus = xla::And(xla::Ne(xla::Lt(trunc_mod, zero), xla::Lt(y, zero)),
                          trunc_mod_not_zero);
  return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod);
}
XLA_MAKE_BINARY(FloorMod,
                FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
187 
// Bitwise ops.
XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
// RightShift is zero-filling (logical) for unsigned element types and
// sign-extending (arithmetic) for signed ones.
XLA_MAKE_BINARY(RightShift,
                (DataTypeIsUnsigned(ctx->input_type(0))
                     ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
                     : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));

XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
// ReciprocalGrad(lhs, rhs) = -(rhs * lhs * lhs).
XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
// RsqrtGrad(lhs, rhs) = lhs^3 * (rhs / -2).
XLA_MAKE_BINARY(
    RsqrtGrad,
    xla::Mul((lhs * lhs) * lhs,
             xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
             extend_dimensions));
// SqrtGrad(lhs, rhs) = (rhs * 0.5) / lhs.
XLA_MAKE_BINARY(
    SqrtGrad,
    xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
             lhs, extend_dimensions));

// Truncating (round-toward-zero) division and its matching remainder.
XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));
217 
// Comparison ops
XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));

// Non-linear ops
// SigmoidGrad(lhs, rhs) = rhs * lhs * (1 - lhs).
XLA_MAKE_BINARY(SigmoidGrad,
                xla::Mul(xla::Mul(rhs, lhs),
                         xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));

// SoftplusGrad(lhs, rhs) = lhs * logistic(rhs).
XLA_MAKE_BINARY(SoftplusGrad, xla::Mul(lhs, xla::Logistic(rhs)));

// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
                xla::Div(lhs,
                         xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)),
                                              xla::Abs(rhs)))));

// TanhGrad(lhs, rhs) = rhs * (1 - lhs * lhs).
XLA_MAKE_BINARY(TanhGrad,
                xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
                                       xla::Mul(lhs, lhs))));

XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));
244 
SquaredDifferenceImpl(DataType dtype,xla::XlaOp x,xla::XlaOp y,const std::vector<int64_t> & extend_dimensions)245 xla::XlaOp SquaredDifferenceImpl(
246     DataType dtype, xla::XlaOp x, xla::XlaOp y,
247     const std::vector<int64_t>& extend_dimensions) {
248   auto difference = xla::Sub(x, y, extend_dimensions);
249   if (DataTypeIsComplex(dtype)) {
250     return xla::Conj(difference) * difference;
251   } else {
252     return xla::Square(difference);
253   }
254 }
255 XLA_MAKE_BINARY(SquaredDifference,
256                 SquaredDifferenceImpl(input_type(0), lhs, rhs,
257                                       extend_dimensions));
258 
IgammaImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)259 xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y,
260                       const BCast& broadcast_helper) {
261   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
262   return xla::Igamma(x, y);
263 }
264 
265 XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper));
266 
IgammaGradAImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)267 xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y,
268                            const BCast& broadcast_helper) {
269   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
270   return xla::IgammaGradA(x, y);
271 }
272 
273 XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper));
274 
RandomGammaGradImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)275 xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y,
276                                const BCast& broadcast_helper) {
277   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
278   return xla::RandomGammaGrad(x, y);
279 }
280 
281 XLA_MAKE_BINARY(RandomGammaGrad,
282                 RandomGammaGradImpl(lhs, rhs, broadcast_helper));
283 
IgammacImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)284 xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
285                        const BCast& broadcast_helper) {
286   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
287   return xla::Igammac(x, y);
288 }
289 
290 XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
291 
PolygammaImpl(xla::XlaOp n,xla::XlaOp x,const BCast & broadcast_helper)292 xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
293                          const BCast& broadcast_helper) {
294   std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
295   return xla::Polygamma(n, x);
296 }
297 
298 XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
299 
ZetaImpl(xla::XlaOp x,xla::XlaOp q,const BCast & broadcast_helper)300 xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
301   std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
302   return xla::Zeta(x, q);
303 }
304 
305 XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
306 
307 #undef XLA_MAKE_BINARY
308 
// ApproximateEqual: elementwise |input(0) - input(1)| < tolerance, producing
// a boolean tensor. `tolerance` is a float op attribute.
class ApproximateEqualOp : public XlaOpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
  }

  // Emits Lt(Abs(input(0) - input(1)), tolerance), converting the scalar
  // tolerance constant to the element type of the difference so the
  // comparison operands agree.
  // (The previous comment here — "Computes the max of the scalar input x and
  // 0" — described Relu, not this op.)
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
    auto abs_shape = b->GetShape(abs);
    OP_REQUIRES_OK(ctx, abs_shape.status());
    auto abs_type = abs_shape.ValueOrDie().element_type();
    auto result =
        xla::Lt(abs, xla::ConvertElementType(
                         xla::ConstantR0<float>(b, tolerance_), abs_type));
    ctx->SetOutput(0, result);
  }

 private:
  float tolerance_;  // Comparison threshold from the "tolerance" attribute.
};
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
332 
333 }  // namespace
334 }  // namespace tensorflow
335