/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Native XLA implementations of simple binary Ops

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {
// A subclass of XlaBinaryOp must build the computation that describes the
// (tensor, tensor) -> tensor function to apply to each element of the input.
#define XLA_MAKE_BINARY(NAME, HLO)                                          \
  class NAME##Op : public XlaBinaryOp {                                     \
   public:                                                                  \
    explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}      \
    xla::XlaOp Computation(                                                 \
        XlaOpKernelContext* ctx, const xla::XlaOp& lhs,                     \
        const absl::Span<const int64_t>& lhs_shape, const xla::XlaOp& rhs,  \
        const absl::Span<const int64_t>& rhs_shape,                         \
        const BCast& broadcast_helper,                                      \
        const std::vector<int64_t>& extend_dimensions) override {           \
      xla::XlaBuilder* b = ctx->builder();                                  \
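      /* Not every HLO expression uses all of the arguments above; the      \
         casts below silence unused-variable warnings. */                   \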
      (void)b;                                                              \
      (void)lhs_shape;                                                      \
      (void)rhs_shape;                                                      \
      (void)extend_dimensions;                                              \
      return HLO;                                                           \
    }                                                                       \
  };                                                                        \
  REGISTER_XLA_OP(Name(#NAME), NAME##Op)
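
// For example, XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions))
// expands to an AddOp kernel whose Computation() emits a single HLO Add, and
// registers it for the TensorFlow "Add" op.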

XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));

// Implementation of DivNoNan. Pseudo-code:
// if (y == 0) {
//   return 0
// } else {
//   return x / y;
// }
static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = XlaHelpers::Zero(b, dtype);
  auto y_equals_0 = xla::Eq(y, zero);
  auto zeros = xla::ZerosLike(x);
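  // Note: Select computes both operands, so x / y is evaluated even where
  // y == 0; those elements are simply discarded in favor of zero.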
  auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
  return result;
}
XLA_MAKE_BINARY(DivNoNan,
                DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));

// Implementation of MulNoNan. Pseudo-code:
// if (y == 0) {
//   return 0
// } else {
//   return x * y;
// }
static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = XlaHelpers::Zero(b, dtype);
  auto y_equals_0 = xla::Eq(y, zero);
  auto zeros = xla::ZerosLike(x);
  auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
  return result;
}
XLA_MAKE_BINARY(MulNoNan,
                MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));

// Implementation of FloorDiv.
//
// For floating-point values, simply returns floor(x / y). For integers, does:
//
// z = x / y
// if (z * y != x && (x < 0) != (y < 0)) {
//   return z - 1;
// } else {
//   return z;
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  if (DataTypeIsFloating(dtype)) {
    if (dtype == DataType::DT_BFLOAT16) {
      // A BF16 division can round up to the ceiling of what the F32 division
      // would compute, so do the intermediate divide in F32 to avoid
      // confusing end users.
      return xla::ConvertElementType(
          xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32),
                              xla::ConvertElementType(y, xla::F32))),
          xla::BF16);
    } else {
      return xla::Floor(xla::Div(x, y));
    }
  }
  if (DataTypeIsUnsigned(dtype)) {
    return xla::Div(x, y);
  }
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::One(b, dtype);
  auto x_div_y = xla::Div(x, y);
  auto round_down = xla::And(xla::Ne(xla::Mul(x_div_y, y), x),
                             xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)));
  return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y);
}
XLA_MAKE_BINARY(FloorDiv,
                FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));

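// Xlogy returns 0 where x == 0 (avoiding the NaN that 0 * log(0) would
// produce) and x * log(y) elsewhere.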
xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
                     const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = xla::ZerosLike(x);
  auto is_zero = xla::Eq(x, zero);
  return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
}
XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));

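// Xlog1py returns 0 where x == 0 and x * log1p(y) elsewhere.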
xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
                       const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto non_zero = xla::Mul(x, xla::Log1p(y));
  auto zero = xla::ZerosLike(non_zero);
  auto x_is_zero = xla::Eq(x, zero);
  return xla::Select(x_is_zero, zero, non_zero);
}
XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper));

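// Xdivy returns 0 where x == 0 (even where y == 0) and x / y elsewhere.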
xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y,
                     const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = xla::ZerosLike(x);
  auto is_zero = xla::Eq(x, zero);
  return xla::Select(is_zero, zero, xla::Div(x, y));
}
XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));

// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return trunc_mod != 0 && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y
//                                                       : trunc_mod;
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                               xla::XlaOp y, const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  auto zero = XlaHelpers::Zero(b, dtype);
  auto trunc_mod = xla::Rem(x, y);
  auto trunc_mod_not_zero = xla::Ne(trunc_mod, zero);
  auto do_plus = xla::And(xla::Ne(xla::Lt(trunc_mod, zero), xla::Lt(y, zero)),
                          trunc_mod_not_zero);
  return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod);
}
XLA_MAKE_BINARY(FloorMod,
                FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));

XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
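// RightShift is a logical (zero-filling) shift for unsigned types and an
// arithmetic (sign-extending) shift for signed types.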
XLA_MAKE_BINARY(RightShift,
                (DataTypeIsUnsigned(ctx->input_type(0))
                     ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
                     : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));

XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
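// reciprocal_grad(y, dy) = -dy * y^2, where y = 1 / x.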
XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
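// rsqrt_grad(y, dy) = dy * (-1/2) * y^3, where y = rsqrt(x).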
XLA_MAKE_BINARY(
    RsqrtGrad,
    xla::Mul((lhs * lhs) * lhs,
             xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
             extend_dimensions));
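// sqrt_grad(y, dy) = dy * 0.5 / y, where y = sqrt(x).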
XLA_MAKE_BINARY(
    SqrtGrad,
    xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
             lhs, extend_dimensions));

XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));

// Comparison ops
XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));

// Non-linear ops
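// sigmoid_grad(y, dy) = dy * y * (1 - y), where y = sigmoid(x).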
XLA_MAKE_BINARY(SigmoidGrad,
                xla::Mul(xla::Mul(rhs, lhs),
                         xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));

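// softplus_grad(dy, x) = dy * sigmoid(x).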
XLA_MAKE_BINARY(SoftplusGrad, xla::Mul(lhs, xla::Logistic(rhs)));

// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
                xla::Div(lhs,
                         xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)),
                                              xla::Abs(rhs)))));

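// tanh_grad(y, dy) = dy * (1 - y^2), where y = tanh(x).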
XLA_MAKE_BINARY(TanhGrad,
                xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
                                       xla::Mul(lhs, lhs))));

XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));

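// For complex inputs, |x - y|^2 = conj(x - y) * (x - y); squaring the
// difference directly would be wrong.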
xla::XlaOp SquaredDifferenceImpl(
    DataType dtype, xla::XlaOp x, xla::XlaOp y,
    const std::vector<int64_t>& extend_dimensions) {
  auto difference = xla::Sub(x, y, extend_dimensions);
  if (DataTypeIsComplex(dtype)) {
    return xla::Conj(difference) * difference;
  } else {
    return xla::Square(difference);
  }
}
XLA_MAKE_BINARY(SquaredDifference,
                SquaredDifferenceImpl(input_type(0), lhs, rhs,
                                      extend_dimensions));

xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y,
                      const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::Igamma(x, y);
}

XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper));

xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y,
                           const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::IgammaGradA(x, y);
}

XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper));

xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y,
                               const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::RandomGammaGrad(x, y);
}

XLA_MAKE_BINARY(RandomGammaGrad,
                RandomGammaGradImpl(lhs, rhs, broadcast_helper));

xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
                       const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::Igammac(x, y);
}

XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));

xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
                         const BCast& broadcast_helper) {
  std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
  return xla::Polygamma(n, x);
}

XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));

xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
  std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
  return xla::Zeta(x, q);
}

XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));

#undef XLA_MAKE_BINARY

class ApproximateEqualOp : public XlaOpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
  }

  // Computes |x - y| < tolerance elementwise.
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
    auto abs_shape = b->GetShape(abs);
    OP_REQUIRES_OK(ctx, abs_shape.status());
    auto abs_type = abs_shape.ValueOrDie().element_type();
    auto result =
        xla::Lt(abs, xla::ConvertElementType(
                         xla::ConstantR0<float>(b, tolerance_), abs_type));
    ctx->SetOutput(0, result);
  }

 private:
  float tolerance_;
};
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);

}  // namespace
}  // namespace tensorflow