/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {

class ResourceApplyGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp handle;
    DataType type = ctx->input_type(1);
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    TensorShape delta_shape = ctx->InputShape(2);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));

    handle = handle - ctx->Input(1) * ctx->Input(2);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(Name("ResourceApplyGradientDescent")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyGradientDescent);

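// Soft-thresholded ("proximal") gradient step shared by the proximal
// optimizers below:
//   prox_var <- var - lr * grad
//   var <- sign(prox_var) * max(|prox_var| - lr * l1, 0) / (1 + lr * l2)
// The l1 shrinkage is applied only when l1 > 0; otherwise prox_var is simply
// divided by (1 + lr * l2).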
xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
                                         xla::XlaOp l1, xla::XlaOp l2,
                                         xla::XlaOp grad) {
  xla::XlaOp one = xla::ScalarLike(lr, 1.0);
  xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
  xla::XlaOp prox_var = var - grad * lr;
  xla::XlaOp l1_gt_zero =
      xla::Sign(prox_var) * xla::Max(xla::Abs(prox_var) - lr * l1, zero);
  xla::XlaOp l1_le_zero = prox_var;
  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero) /
         (one + lr * l2);
}
class ResourceApplyProximalGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape delta_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));
    xla::XlaOp alpha = ctx->Input(1);
    xla::XlaOp l1 = ctx->Input(2);
    xla::XlaOp l2 = ctx->Input(3);
    xla::XlaOp delta = ctx->Input(4);
    var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyProximalGradientDescent);

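// Momentum update, as implemented below:
//   accum <- accum * momentum + grad
//   var <- var - lr * accum                        (use_nesterov = false)
//   var <- var - lr * (grad + accum * momentum)    (use_nesterov = true)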
class ResourceApplyMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

    accum = accum * momentum + grad;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var - (grad * lr + accum * momentum * lr);
    } else {
      var = var - accum * lr;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
                ResourceApplyMomentum);

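// Keras-style momentum update, as implemented below:
//   accum <- accum * momentum - lr * grad
//   var <- var + accum                              (use_nesterov = false)
//   var <- var + accum * momentum - lr * grad       (use_nesterov = true)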
class ResourceApplyKerasMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

    accum = accum * momentum - grad * lr;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var + accum * momentum - grad * lr;
    } else {
      var = var + accum;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyKerasMomentum")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyKerasMomentum);

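// Adagrad update, as implemented below (the accumulator update is skipped
// when the update_slots attribute is false):
//   accum <- accum + grad^2
//   var <- var - lr * grad / sqrt(accum)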
class ResourceApplyAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);

    if (update_slots_) {
      accum = accum + xla::Square(grad);
    }
    var = var - grad * lr * xla::Rsqrt(accum);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool update_slots_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdagrad);

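// AdagradV2 update; unlike Adagrad above, an explicit epsilon is added to the
// denominator:
//   accum <- accum + grad^2
//   var <- var - lr * grad / (sqrt(accum) + epsilon)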
class ResourceApplyAdagradV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape epsilon_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp epsilon = ctx->Input(3);
    xla::XlaOp grad = ctx->Input(4);

    if (update_slots_) {
      accum = accum + xla::Square(grad);
    }
    var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool update_slots_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdagradV2);

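// Proximal Adagrad: accumulates squared gradients as in Adagrad, then applies
// ProximalGradientDescentUpdate with the per-step Adagrad learning rate
// lr / sqrt(accum).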
class ResourceApplyProximalAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(5);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape: ",
                    var_shape.DebugString(), " vs ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp l1 = ctx->Input(3);
    xla::XlaOp l2 = ctx->Input(4);
    xla::XlaOp grad = ctx->Input(5);
    accum = accum + xla::Square(grad);
    // Adagrad learning rate.
    xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
    var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalAdagrad")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyProximalAdagrad);

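// Adagrad Dual Averaging: the gradient and squared-gradient accumulators are
// updated every step and var is recomputed directly from them:
//   accum <- accum + grad
//   squared_accum <- squared_accum + grad^2
//   var <- -lr * shrink(accum) / (global_step * lr * l2 + sqrt(squared_accum))
// where shrink(accum) soft-thresholds accum by global_step * l1 when l1 > 0
// and is the identity otherwise.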
class ResourceApplyAdagradDA : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, squared_accum_shape;
    xla::XlaOp var, accum, squared_accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
                                               &squared_accum));
    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(squared_accum_shape),
        errors::InvalidArgument(
            "var and squared accum do not have the same shape",
            var_shape.DebugString(), " ", squared_accum_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape l1_shape = ctx->InputShape(5);
    TensorShape l2_shape = ctx->InputShape(6);
    TensorShape global_step_shape = ctx->InputShape(7);

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
                errors::InvalidArgument("global step is not a scalar: ",
                                        global_step_shape.DebugString()));

    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp l1 = ctx->Input(5);
    xla::XlaOp l2 = ctx->Input(6);
    xla::XlaOp global_step =
        XlaHelpers::ConvertElementType(ctx->Input(7), dtype_);

    accum = accum + grad;
    squared_accum = squared_accum + xla::Square(grad);
    xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
    xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
    xla::XlaOp l1_le_zero = -lr * accum / denominator;
    xla::XlaOp l1_gt_zero =
        -lr * xla::Sign(accum) *
        xla::Max(xla::Abs(accum) - global_step * l1, zero) / denominator;

    var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagradDA);

class ResourceApplyAdam : public XlaOpKernel {
 public:
  explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape beta2_power_shape = ctx->InputShape(4);
    TensorShape lr_shape = ctx->InputShape(5);
    TensorShape beta1_shape = ctx->InputShape(6);
    TensorShape beta2_shape = ctx->InputShape(7);
    TensorShape epsilon_shape = ctx->InputShape(8);
    TensorShape grad_shape = ctx->InputShape(9);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp beta2_power = ctx->Input(4);
    xla::XlaOp lr = ctx->Input(5);
    xla::XlaOp beta1 = ctx->Input(6);
    xla::XlaOp beta2 = ctx->Input(7);
    xla::XlaOp epsilon = ctx->Input(8);
    xla::XlaOp grad = ctx->Input(9);

    // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
    // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
    // if not use_nesterov:
    //   variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
    // if use_nesterov:
    //   variable <- variable - alpha * (m_t * beta1 + (1 - beta1) * g_t) /
    //               (sqrt(v_t) + epsilon)

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

    xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
    auto m_t = m + (grad - m) * (one - beta1);
    v = v + (xla::Square(grad) - v) * (one - beta2);
    if (use_nesterov_) {
      var = var - alpha * (m_t * beta1 + (one - beta1) * grad) /
                      (xla::Sqrt(v) + epsilon);
    } else {
      var = var - m_t * alpha / (xla::Sqrt(v) + epsilon);
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m_t));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
  bool use_nesterov_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdam").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdam);

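// AdaMax update, as implemented below:
//   m <- beta1 * m + (1 - beta1) * grad
//   v <- max(beta2 * v, |grad|)
//   var <- var - lr / (1 - beta1_power) * m / (v + epsilon)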
class ResourceApplyAdaMax : public XlaOpKernel {
 public:
  explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape beta1_shape = ctx->InputShape(5);
    TensorShape beta2_shape = ctx->InputShape(6);
    TensorShape epsilon_shape = ctx->InputShape(7);
    TensorShape grad_shape = ctx->InputShape(8);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp beta1 = ctx->Input(5);
    xla::XlaOp beta2 = ctx->Input(6);
    xla::XlaOp epsilon = ctx->Input(7);
    xla::XlaOp grad = ctx->Input(8);

    xla::XlaOp one = xla::ScalarLike(lr, 1.0);
    m = beta1 * m + (one - beta1) * grad;
    v = xla::Max(beta2 * v, xla::Abs(grad));
    var = var - lr / (one - beta1_power) * (m / (v + epsilon));

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdaMax);

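// RMSProp; when centered_ is set (see ResourceApplyCenteredRMSProp below), an
// extra slot mg tracks a running mean of the gradient and the denominator
// becomes ms - mg^2 + epsilon instead of ms + epsilon.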
class ResourceApplyRMSProp : public XlaOpKernel {
 public:
  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, ms_shape, mom_shape, mg_shape;
    xla::XlaOp var, ms, mom, mg;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
    }
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));

    TensorShape lr_shape = ctx->InputShape("lr");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape rho_shape = ctx->InputShape("rho");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));
    TensorShape momentum_shape = ctx->InputShape("momentum");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));
    TensorShape epsilon_shape = ctx->InputShape("epsilon");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape("grad");

    // var should be the same shape as mom and ms.
    OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        ms_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(mom_shape),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var_shape.DebugString(), " ", mom_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input("lr");
    xla::XlaOp rho = ctx->Input("rho");
    xla::XlaOp momentum = ctx->Input("momentum");
    xla::XlaOp epsilon = ctx->Input("epsilon");
    xla::XlaOp grad = ctx->Input("grad");

    // ms <- rho * ms_{t-1} + (1-rho) * grad * grad
    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    // var <- var - mom
    //
    // We use an alternate formulation of the ms equation:
    //
    //    ms <- ms + (grad**2 - ms) * (1 - rho)
    //
    // Which expands to:
    //
    //    ms <- ms + grad**2 - rho * grad ** 2 - ms + ms * rho
    //
    // Which simplifies to:
    //
    //    ms <- grad**2 (1 - rho) + ms * rho
    //
    // Which is the equation listed above.
    xla::XlaOp one = xla::ScalarLike(ms, 1.0);
    xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
    xla::XlaOp denominator;
    if (centered_) {
      mg = grad * (one - rho) + mg * rho;
      denominator = new_ms - xla::Square(mg) + epsilon;
    } else {
      denominator = new_ms + epsilon;
    }
    xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
    xla::XlaOp new_var = var - new_mom;

    OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
  }

 protected:
  bool centered_ = false;

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyRMSProp);

class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
 public:
  explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
      : ResourceApplyRMSProp(ctx) {
    centered_ = true;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyCenteredRMSProp")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyCenteredRMSProp);

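// Shared FTRL implementation used by ResourceApplyFtrl (has_l2_shrinkage =
// false) and ResourceApplyFtrlV2 (has_l2_shrinkage = true). When
// multiply_linear_by_lr is true, the "linear" slot stores lr * linear, so the
// gradient term, the clamp bounds, and the quadratic term below are scaled by
// lr instead of divided by it.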
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage,
                 bool multiply_linear_by_lr) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape var_shape, accum_shape, linear_shape;
  xla::XlaOp var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));

  OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
              errors::InvalidArgument(
                  "var and accum do not have the same shape",
                  var_shape.DebugString(), " ", accum_shape.DebugString()));

  OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape),
              errors::InvalidArgument(
                  "var and linear do not have the same shape",
                  var_shape.DebugString(), " ", linear_shape.DebugString()));

  TensorShape grad_shape = ctx->InputShape(3);
  TensorShape lr_shape = ctx->InputShape(4);
  TensorShape l1_shape = ctx->InputShape(5);
  TensorShape l2_shape = ctx->InputShape(6);
  TensorShape l2_shrinkage_shape;
  TensorShape lr_power_shape;
  if (has_l2_shrinkage) {
    l2_shrinkage_shape = ctx->InputShape(7);
    lr_power_shape = ctx->InputShape(8);
  } else {
    lr_power_shape = ctx->InputShape(7);
  }

  OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
              errors::InvalidArgument("var and grad do not have the same shape",
                                      var_shape.DebugString(), " ",
                                      grad_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(lr_shape),
      errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l1_shape),
      errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l2_shape),
      errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString()));

  if (has_l2_shrinkage) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape),
                errors::InvalidArgument("l2_shrinkage is not a scalar: ",
                                        l2_shrinkage_shape.DebugString()));
  }

  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape),
              errors::InvalidArgument("lr_power is not a scalar: ",
                                      lr_power_shape.DebugString()));

  xla::XlaOp grad = ctx->Input(3);
  xla::XlaOp lr = ctx->Input(4);
  xla::XlaOp l1 = ctx->Input(5);
  xla::XlaOp l2 = ctx->Input(6);
  xla::XlaOp l2_shrinkage;
  xla::XlaOp lr_power;
  if (has_l2_shrinkage) {
    l2_shrinkage = ctx->Input(7);
    lr_power = ctx->Input(8);
  } else {
    lr_power = ctx->Input(7);
  }

  // grad_to_use = grad + 2 * l2_shrinkage * var
  // new_accum = accum + grad * grad
  // linear += grad_to_use -
  //           (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
  // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
  // linear_clipped = clamp linear in [-l1, l1]
  // var = (linear_clipped - linear) / quadratic
  // accum = new_accum

  xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  xla::XlaOp grad_to_use;
  if (has_l2_shrinkage) {
    grad_to_use = grad + two * l2_shrinkage * var;
  } else {
    grad_to_use = grad;
  }

  xla::XlaOp new_accum = accum + xla::Square(grad);
  xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
  xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
  if (multiply_linear_by_lr) {
    linear =
        linear + grad_to_use * lr - (new_accum_lr_pow - accum_lr_pow) * var;
  } else {
    linear =
        linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
  }
  xla::XlaOp linear_clipped =
      (multiply_linear_by_lr ? xla::Clamp(-l1 * lr, linear, l1 * lr)
                             : xla::Clamp(-l1, linear, l1));
  xla::XlaOp quadratic =
      (multiply_linear_by_lr ? new_accum_lr_pow + two * l2 * lr
                             : new_accum_lr_pow / lr + two * l2);
  var = (linear_clipped - linear) / quadratic;
  accum = new_accum;

  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear));
}

class ResourceApplyFtrl : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false,
                /*multiply_linear_by_lr=*/multiply_linear_by_lr_);
  }

 private:
  DataType dtype_;

  // Whether to keep the "linear" slot variable multiplied by the learning rate.
  bool multiply_linear_by_lr_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrl);

class ResourceApplyFtrlV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true,
                /*multiply_linear_by_lr=*/multiply_linear_by_lr_);
  }

 private:
  DataType dtype_;

  // Whether to keep the "linear" slot variable multiplied by the learning rate.
  bool multiply_linear_by_lr_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrlV2);

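// Adadelta update, as implemented below:
//   accum <- rho * accum + (1 - rho) * grad^2
//   update <- sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   accum_update <- rho * accum_update + (1 - rho) * update^2
//   var <- var - lr * update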
class ResourceApplyAdadelta : public XlaOpKernel {
 public:
  explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, accum_update_shape;
    xla::XlaOp var, accum, accum_update;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
                                               &accum_update));

    TensorShape lr_shape = ctx->InputShape(3);
    TensorShape rho_shape = ctx->InputShape(4);
    TensorShape epsilon_shape = ctx->InputShape(5);
    TensorShape grad_shape = ctx->InputShape(6);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(3);
    xla::XlaOp rho = ctx->Input(4);
    xla::XlaOp epsilon = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

    accum = rho * accum + (one - rho) * xla::Square(grad);
    xla::XlaOp update =
        xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad;
    accum_update = rho * accum_update + (one - rho) * xla::Square(update);
    var = var - update * lr;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdadelta);

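// Shared implementation of the AddSign and PowerSign optimizers:
//   m <- beta * m + (1 - beta) * grad
//   decay <- sign(grad) * sign(m) * sign_decay
//   var <- var - lr * ComputeGradientScale(alpha, decay) * grad
// where ComputeGradientScale returns alpha + decay for AddSign and
// exp(logbase * decay) for PowerSign (input 3 is alpha or logbase,
// respectively).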
class ResourceApplySignBase : public XlaOpKernel {
 public:
  explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape;
    xla::XlaOp var, m;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(6);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    CheckScalarParams(ctx);

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp alpha = ctx->Input(3);
    xla::XlaOp sign_decay = ctx->Input(4);
    xla::XlaOp beta = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

    m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
    xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;

    xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
    var = var - lr * grad_scale * grad;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
  }

  virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
    TensorShape lr_shape = ctx->InputShape(2);
    TensorShape sign_decay_shape = ctx->InputShape(4);
    TensorShape beta_shape = ctx->InputShape(5);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta_shape.DebugString()));
  }

  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
                                          xla::XlaOp decay) = 0;

 private:
  DataType dtype_;
};

class ResourceApplyAddSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape alpha_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return alpha + decay;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyAddSign);

class ResourceApplyPowerSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape logbase_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return xla::Exp(alpha * decay);
  }
};
REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyPowerSign);

}  // namespace
}  // namespace tensorflow