/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {

class ResourceApplyGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp handle;
    DataType type = ctx->input_type(1);
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    TensorShape delta_shape = ctx->InputShape(2);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));

    handle = handle - ctx->Input(1) * ctx->Input(2);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(Name("ResourceApplyGradientDescent")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyGradientDescent);

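// Computes the proximal gradient descent update used by the kernels below:
//   prox_var = var - lr * grad
//   var = sign(prox_var) * max(|prox_var| - lr * l1, 0) / (1 + lr * l2) if l1 > 0
//   var = prox_var / (1 + lr * l2)                                      otherwise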
xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
                                         xla::XlaOp l1, xla::XlaOp l2,
                                         xla::XlaOp grad) {
  xla::XlaOp one = xla::ScalarLike(lr, 1.0);
  xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
  xla::XlaOp prox_var = var - grad * lr;
  xla::XlaOp l1_gt_zero =
      xla::Sign(prox_var) * xla::Max(xla::Abs(prox_var) - lr * l1, zero);
  xla::XlaOp l1_le_zero = prox_var;
  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero) /
         (one + lr * l2);
}

class ResourceApplyProximalGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape delta_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));
    xla::XlaOp alpha = ctx->Input(1);
    xla::XlaOp l1 = ctx->Input(2);
    xla::XlaOp l2 = ctx->Input(3);
    xla::XlaOp delta = ctx->Input(4);
    var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyProximalGradientDescent);

class ResourceApplyMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

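    // accum = momentum * accum + grad
    // var  -= use_nesterov ? lr * grad + lr * momentum * accum : lr * accum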
    accum = accum * momentum + grad;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var - (grad * lr + accum * momentum * lr);
    } else {
      var = var - accum * lr;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
                ResourceApplyMomentum);

class ResourceApplyKerasMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

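    // accum = momentum * accum - lr * grad
    // var  += use_nesterov ? momentum * accum - lr * grad : accum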
    accum = accum * momentum - grad * lr;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var + accum * momentum - grad * lr;
    } else {
      var = var + accum;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyKerasMomentum")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyKerasMomentum);

class ResourceApplyAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);

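    // accum += grad^2  (when update_slots)
    // var   -= lr * grad / sqrt(accum)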
    if (update_slots_) {
      accum = accum + xla::Square(grad);
    }
    var = var - grad * lr * xla::Rsqrt(accum);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool update_slots_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdagrad);

class ResourceApplyAdagradV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape epsilon_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp epsilon = ctx->Input(3);
    xla::XlaOp grad = ctx->Input(4);

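    // accum += grad^2  (when update_slots)
    // var   -= lr * grad / (sqrt(accum) + epsilon)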
    if (update_slots_) {
      accum = accum + xla::Square(grad);
    }
    var = var - grad * lr / (xla::Sqrt(accum) + epsilon);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool update_slots_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdagradV2);

class ResourceApplyProximalAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(5);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape: ",
                    var_shape.DebugString(), " vs ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp l1 = ctx->Input(3);
    xla::XlaOp l2 = ctx->Input(4);
    xla::XlaOp grad = ctx->Input(5);
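    // accum += grad^2; ProximalGradientDescentUpdate (defined above) is then
    // applied with the per-element Adagrad learning rate lr / sqrt(accum).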
    accum = accum + xla::Square(grad);
    // Adagrad learning rate.
    xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
    var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalAdagrad")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyProximalAdagrad);

class ResourceApplyAdagradDA : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, squared_accum_shape;
    xla::XlaOp var, accum, squared_accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
                                               &squared_accum));
    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(squared_accum_shape),
        errors::InvalidArgument(
            "var and squared accum do not have the same shape",
            var_shape.DebugString(), " ", squared_accum_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape l1_shape = ctx->InputShape(5);
    TensorShape l2_shape = ctx->InputShape(6);
    TensorShape global_step_shape = ctx->InputShape(7);

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
                errors::InvalidArgument("global step is not a scalar: ",
                                        global_step_shape.DebugString()));

    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp l1 = ctx->Input(5);
    xla::XlaOp l2 = ctx->Input(6);
    xla::XlaOp global_step =
        XlaHelpers::ConvertElementType(ctx->Input(7), dtype_);

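    // accum         += grad
    // squared_accum += grad^2
    // var = -lr * (l1 > 0 ? sign(accum) * max(|accum| - global_step * l1, 0)
    //                     : accum)
    //       / (global_step * lr * l2 + sqrt(squared_accum))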
    accum = accum + grad;
    squared_accum = squared_accum + xla::Square(grad);
    xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
    xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
    xla::XlaOp l1_le_zero = -lr * accum / denominator;
    xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
                            xla::Max(xla::Abs(accum) - global_step * l1, zero) /
                            denominator;

    var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagradDA);

class ResourceApplyAdam : public XlaOpKernel {
 public:
  explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape beta2_power_shape = ctx->InputShape(4);
    TensorShape lr_shape = ctx->InputShape(5);
    TensorShape beta1_shape = ctx->InputShape(6);
    TensorShape beta2_shape = ctx->InputShape(7);
    TensorShape epsilon_shape = ctx->InputShape(8);
    TensorShape grad_shape = ctx->InputShape(9);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp beta2_power = ctx->Input(4);
    xla::XlaOp lr = ctx->Input(5);
    xla::XlaOp beta1 = ctx->Input(6);
    xla::XlaOp beta2 = ctx->Input(7);
    xla::XlaOp epsilon = ctx->Input(8);
    xla::XlaOp grad = ctx->Input(9);

    // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
    // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
    // if not use_nesterov:
    //   variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
    // if use_nesterov:
    //   variable <- variable - alpha * (m_t * beta1 + (1 - beta1) * g_t) /
    //   (sqrt(v_t) + epsilon)

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

    xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
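    // Note: m + (grad - m) * (1 - beta1) is algebraically the same as
    // beta1 * m + (1 - beta1) * grad; the v update below uses the analogous
    // reformulation.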
    auto m_t = m + (grad - m) * (one - beta1);
    v = v + (xla::Square(grad) - v) * (one - beta2);
    if (use_nesterov_) {
      var = var - alpha * (m_t * beta1 + (one - beta1) * grad) /
                      (xla::Sqrt(v) + epsilon);
    } else {
      var = var - m_t * alpha / (xla::Sqrt(v) + epsilon);
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m_t));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
  bool use_nesterov_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdam").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdam);

class ResourceApplyAdaMax : public XlaOpKernel {
 public:
  explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape beta1_shape = ctx->InputShape(5);
    TensorShape beta2_shape = ctx->InputShape(6);
    TensorShape epsilon_shape = ctx->InputShape(7);
    TensorShape grad_shape = ctx->InputShape(8);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp beta1 = ctx->Input(5);
    xla::XlaOp beta2 = ctx->Input(6);
    xla::XlaOp epsilon = ctx->Input(7);
    xla::XlaOp grad = ctx->Input(8);

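    // m   = beta1 * m + (1 - beta1) * grad
    // v   = max(beta2 * v, |grad|)
    // var -= lr / (1 - beta1_power) * m / (v + epsilon)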
    xla::XlaOp one = xla::ScalarLike(lr, 1.0);
    m = beta1 * m + (one - beta1) * grad;
    v = xla::Max(beta2 * v, xla::Abs(grad));
    var = var - lr / (one - beta1_power) * (m / (v + epsilon));

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdaMax);

class ResourceApplyRMSProp : public XlaOpKernel {
 public:
  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, ms_shape, mom_shape, mg_shape;
    xla::XlaOp var, ms, mom, mg;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
    }
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));

    TensorShape lr_shape = ctx->InputShape("lr");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape rho_shape = ctx->InputShape("rho");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));
    TensorShape momentum_shape = ctx->InputShape("momentum");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));
    TensorShape epsilon_shape = ctx->InputShape("epsilon");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape("grad");

    // var should be the same shape as mom and ms.
    OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        ms_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(mom_shape),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var_shape.DebugString(), " ", mom_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input("lr");
    xla::XlaOp rho = ctx->Input("rho");
    xla::XlaOp momentum = ctx->Input("momentum");
    xla::XlaOp epsilon = ctx->Input("epsilon");
    xla::XlaOp grad = ctx->Input("grad");

    // ms <- rho * ms_{t-1} + (1-rho) * grad * grad
    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    // var <- var - mom
    //
    // We use an alternate formulation of the ms equation:
    //
    //    ms <- ms + (grad**2 - ms) * (1 - rho)
    //
    // Which expands to:
    //
    //    ms <- ms + grad**2 - rho * grad ** 2 - ms + ms * rho
    //
    // Which simplifies to:
    //
    //    ms <- grad**2 (1 - rho) + ms * rho
    //
    // Which is the equation listed above.
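    //
    // In the centered variant (ResourceApplyCenteredRMSProp), mg tracks a
    // decayed mean of the gradient:
    //
    //    mg <- rho * mg_{t-1} + (1-rho) * grad
    //
    // and ms is replaced by the variance estimate ms - mg**2 in the
    // denominator.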
    xla::XlaOp one = xla::ScalarLike(ms, 1.0);
    xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
    xla::XlaOp denominator;
    if (centered_) {
      mg = grad * (one - rho) + mg * rho;
      denominator = new_ms - xla::Square(mg) + epsilon;
    } else {
      denominator = new_ms + epsilon;
    }
    xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
    xla::XlaOp new_var = var - new_mom;

    OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
  }

 protected:
  bool centered_ = false;

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyRMSProp);

class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
 public:
  explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
      : ResourceApplyRMSProp(ctx) {
    centered_ = true;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyCenteredRMSProp")
                    .TypeConstraint("T", kFloatAndComplexTypes),
                ResourceApplyCenteredRMSProp);

void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage,
                 bool multiply_linear_by_lr) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape var_shape, accum_shape, linear_shape;
  xla::XlaOp var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));

  OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
              errors::InvalidArgument(
                  "var and accum do not have the same shape",
                  var_shape.DebugString(), " ", accum_shape.DebugString()));

  OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape),
              errors::InvalidArgument(
                  "var and linear do not have the same shape",
                  var_shape.DebugString(), " ", linear_shape.DebugString()));

  TensorShape grad_shape = ctx->InputShape(3);
  TensorShape lr_shape = ctx->InputShape(4);
  TensorShape l1_shape = ctx->InputShape(5);
  TensorShape l2_shape = ctx->InputShape(6);
  TensorShape l2_shrinkage_shape;
  TensorShape lr_power_shape;
  if (has_l2_shrinkage) {
    l2_shrinkage_shape = ctx->InputShape(7);
    lr_power_shape = ctx->InputShape(8);
  } else {
    lr_power_shape = ctx->InputShape(7);
  }

  OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
              errors::InvalidArgument("var and grad do not have the same shape",
                                      var_shape.DebugString(), " ",
                                      grad_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(lr_shape),
      errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l1_shape),
      errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l2_shape),
      errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString()));

  if (has_l2_shrinkage) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape),
                errors::InvalidArgument("l2_shrinkage is not a scalar: ",
                                        l2_shrinkage_shape.DebugString()));
  }

  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape),
              errors::InvalidArgument("lr_power is not a scalar: ",
                                      lr_power_shape.DebugString()));

  xla::XlaOp grad = ctx->Input(3);
  xla::XlaOp lr = ctx->Input(4);
  xla::XlaOp l1 = ctx->Input(5);
  xla::XlaOp l2 = ctx->Input(6);
  xla::XlaOp l2_shrinkage;
  xla::XlaOp lr_power;
  if (has_l2_shrinkage) {
    l2_shrinkage = ctx->Input(7);
    lr_power = ctx->Input(8);
  } else {
    lr_power = ctx->Input(7);
  }

  // grad_to_use = grad + 2 * l2_shrinkage * var
  // new_accum = accum + grad * grad
  // linear += grad_to_use -
  //     (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
  // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
  // linear_clipped = clamp linear in [-l1, l1]
  // var = (linear_clipped - linear) / quadratic
  // accum = new_accum
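  //
  // When multiply_linear_by_lr is true, the "linear" slot is kept
  // pre-multiplied by lr, so lr is folded into the linear update, the
  // clipping bounds, and the quadratic term below instead of appearing as a
  // division.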

  xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  xla::XlaOp grad_to_use;
  if (has_l2_shrinkage) {
    grad_to_use = grad + two * l2_shrinkage * var;
  } else {
    grad_to_use = grad;
  }

  xla::XlaOp new_accum = accum + xla::Square(grad);
  xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
  xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
  if (multiply_linear_by_lr) {
    linear =
        linear + grad_to_use * lr - (new_accum_lr_pow - accum_lr_pow) * var;
  } else {
    linear =
        linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
  }
  xla::XlaOp linear_clipped =
      (multiply_linear_by_lr ? xla::Clamp(-l1 * lr, linear, l1 * lr)
                             : xla::Clamp(-l1, linear, l1));
  xla::XlaOp quadratic =
      (multiply_linear_by_lr ? new_accum_lr_pow + two * l2 * lr
                             : new_accum_lr_pow / lr + two * l2);
  var = (linear_clipped - linear) / quadratic;
  accum = new_accum;

  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear));
}

class ResourceApplyFtrl : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false,
                /*multiply_linear_by_lr=*/multiply_linear_by_lr_);
  }

 private:
  DataType dtype_;

  // Whether to keep the "linear" slot variable multiplied by the learning rate.
  bool multiply_linear_by_lr_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrl);

class ResourceApplyFtrlV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true,
                /*multiply_linear_by_lr=*/multiply_linear_by_lr_);
  }

 private:
  DataType dtype_;

  // Whether to keep the "linear" slot variable multiplied by the learning rate.
  bool multiply_linear_by_lr_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrlV2);

class ResourceApplyAdadelta : public XlaOpKernel {
 public:
  explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, accum_update_shape;
    xla::XlaOp var, accum, accum_update;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
                                               &accum_update));

    TensorShape lr_shape = ctx->InputShape(3);
    TensorShape rho_shape = ctx->InputShape(4);
    TensorShape epsilon_shape = ctx->InputShape(5);
    TensorShape grad_shape = ctx->InputShape(6);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(3);
    xla::XlaOp rho = ctx->Input(4);
    xla::XlaOp epsilon = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

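    // accum        = rho * accum + (1 - rho) * grad^2
    // update       = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
    // accum_update = rho * accum_update + (1 - rho) * update^2
    // var         -= lr * update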
    accum = rho * accum + (one - rho) * xla::Square(grad);
    xla::XlaOp update =
        xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad;
    accum_update = rho * accum_update + (one - rho) * xla::Square(update);
    var = var - update * lr;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatAndComplexTypes),
    ResourceApplyAdadelta);

class ResourceApplySignBase : public XlaOpKernel {
 public:
  explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape;
    xla::XlaOp var, m;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(6);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    CheckScalarParams(ctx);

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp alpha = ctx->Input(3);
    xla::XlaOp sign_decay = ctx->Input(4);
    xla::XlaOp beta = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

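    // m     = beta * m + (1 - beta) * grad
    // decay = sign(grad) * sign(m) * sign_decay
    // var  -= lr * ComputeGradientScale(alpha, decay) * grad
    // where ComputeGradientScale returns alpha + decay for AddSign and
    // exp(logbase * decay) for PowerSign (both parameters arrive as input 3).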
    m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
    xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;

    xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
    var = var - lr * grad_scale * grad;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
  }

  virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
    TensorShape lr_shape = ctx->InputShape(2);
    TensorShape sign_decay_shape = ctx->InputShape(4);
    TensorShape beta_shape = ctx->InputShape(5);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta_shape.DebugString()));
  }

  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
                                          xla::XlaOp decay) = 0;

 private:
  DataType dtype_;
};

class ResourceApplyAddSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape alpha_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return alpha + decay;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyAddSign);

class ResourceApplyPowerSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape logbase_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return xla::Exp(alpha * decay);
  }
};
REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyPowerSign);

}  // namespace
}  // namespace tensorflow