xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/optimizer_v2/optimizer_v2.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Version 2 of class Optimizer."""
16# pylint: disable=g-bad-name
17
18import abc
19import contextlib
20import functools
21import warnings
22
23from tensorflow.python.distribute import central_storage_strategy
24from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
25from tensorflow.python.distribute import parameter_server_strategy
26from tensorflow.python.distribute import parameter_server_strategy_v2
27from tensorflow.python.distribute import values as ds_values
28from tensorflow.python.eager import backprop
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.keras import backend
35from tensorflow.python.keras import initializers
36from tensorflow.python.keras.engine import base_layer_utils
37from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
38from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
39from tensorflow.python.keras.utils import generic_utils
40from tensorflow.python.keras.utils import layer_utils
41from tensorflow.python.keras.utils import tf_inspect
42from tensorflow.python.keras.utils import tf_utils
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import gen_resource_variable_ops
46from tensorflow.python.ops import gradients
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import variables as tf_variables
49from tensorflow.python.saved_model import revived_types
50from tensorflow.python.trackable import base as trackable
51from tensorflow.python.util import nest
52from tensorflow.python.util.tf_export import keras_export
53
54
55_DEFAULT_VALID_DTYPES = frozenset([
56    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
57    dtypes.complex64, dtypes.complex128
58])
59
60
61def _deduplicate_indexed_slices(values, indices):
62  """Sums `values` associated with any non-unique `indices`.
63
64  Args:
65    values: A `Tensor` with rank >= 1.
66    indices: A one-dimensional integer `Tensor`, indexing into the first
67      dimension of `values` (as in an IndexedSlices object).
68
69  Returns:
70    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
71    de-duplicated version of `indices` and `summed_values` contains the sum of
72    `values` slices associated with each unique index.
73  """
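  # For example (a sketch): values=[[1.], [2.], [3.]] and indices=[0, 0, 2]
  # yield summed_values=[[3.], [3.]] and unique_indices=[0, 2].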
74  unique_indices, new_index_positions = array_ops.unique(indices)
75  summed_values = math_ops.unsorted_segment_sum(
76      values, new_index_positions,
77      array_ops.shape(unique_indices)[0])
78  return (summed_values, unique_indices)
79
80
81class NullContextmanager(object):
82
83  def __init__(self, *args, **kwargs):
84    pass
85
86  def __enter__(self):
87    pass
88
89  def __exit__(self, type_arg, value_arg, traceback_arg):
90    return False  # False values do not suppress exceptions
91
92
93def name_scope_only_in_function_or_graph(name):
94  """Internal-only entry point for `name_scope*`.
95
96  Enters a compat.v1.name_scope only when in a function or graph,
97  not when running fully eagerly.
98
99  Args:
100    name: The name argument that is passed to the op function.
101
102  Returns:
103    `name_scope*` context manager.
104  """
105  if not context.executing_eagerly():
106    return ops.name_scope_v1(name)
107  else:
108    return NullContextmanager()
109
110
111@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta)
112class OptimizerV2(trackable.Trackable):
113  """Base class for Keras optimizers.
114
115  You should not use this class directly, but instead instantiate one of its
116  subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.
117
118  ### Usage
119
120  ```python
121  # Create an optimizer with the desired parameters.
122  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
123  # `loss` is a callable that takes no argument and returns the value
124  # to minimize.
125  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
126  # In graph mode, returns op that minimizes the loss by updating the listed
127  # variables.
128  opt_op = opt.minimize(loss, var_list=[var1, var2])
129  opt_op.run()
130  # In eager mode, simply call minimize to update the list of variables.
131  opt.minimize(loss, var_list=[var1, var2])
132  ```
133
134  ### Usage in custom training loops
135
136  In Keras models, variables are sometimes created when the model is first
137  called, instead of at construction time. Examples include 1) sequential
138  models without a pre-defined input shape, or 2) subclassed models. Pass
139  `var_list` as a callable in these cases.
140
141  Example:
142
143  ```python
144  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
145  model = tf.keras.Sequential()
146  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
147  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
148  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
149  var_list_fn = lambda: model.trainable_weights
150  for input, output in data:
151    opt.minimize(loss_fn, var_list_fn)
152  ```
153
154  ### Processing gradients before applying them
155
156  Calling `minimize()` takes care of both computing the gradients and
157  applying them to the variables.  If you want to process the gradients
158  before applying them you can instead use the optimizer in three steps:
159
160  1.  Compute the gradients with `tf.GradientTape`.
161  2.  Process the gradients as you wish.
162  3.  Apply the processed gradients with `apply_gradients()`.
163
164  Example:
165
166  ```python
167  # Create an optimizer.
168  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
169
170  # Compute the gradients for a list of variables.
171  with tf.GradientTape() as tape:
172    loss = <call_loss_function>
173  vars = <list_of_variables>
174  grads = tape.gradient(loss, vars)
175
176  # Process the gradients, for example cap them, etc.
177  # capped_grads = [MyCapper(g) for g in grads]
178  processed_grads = [process_gradient(g) for g in grads]
179
180  # Ask the optimizer to apply the processed gradients.
181  opt.apply_gradients(zip(processed_grads, vars))
182  ```
183
184  ### Use with `tf.distribute.Strategy`
185
186  This optimizer class is `tf.distribute.Strategy` aware, which means it
187  automatically sums gradients across all replicas. To average gradients,
188  you divide your loss by the global batch size, which is done
189  automatically if you use `tf.keras` built-in training or evaluation loops.
190  See the `reduction` argument of your loss, which should be set to
191  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
192  `tf.keras.losses.Reduction.SUM` for summing.
193
194  To aggregate gradients yourself, call `apply_gradients` with
195  `experimental_aggregate_gradients` set to False. This is useful if you need to
196  process aggregated gradients.
197
198  If you are not using these and you want to average gradients, you should use
199  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
200  global batch size. Note that when using `tf.distribute.Strategy`, the first
201  component of a tensor's shape is the *replica-local* batch size, which is off
202  by a factor equal to the number of replicas being used to compute a single
203  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
204  resulting in gradients that can be many times too big.
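
  For example, a sketch of manual loss scaling (`labels`, `predictions` and
  `global_batch_size` are placeholders, not defined here):

  ```python
  per_example_loss = tf.keras.losses.mse(labels, predictions)
  # Divide by the *global* batch size, not the replica-local batch size.
  loss = tf.math.reduce_sum(per_example_loss) / global_batch_size
  ```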
205
206  ### Variable Constraints
207
208  All Keras optimizers respect variable constraints. If a constraint function
209  is passed to any variable, the constraint will be applied to the variable
210  after the gradient has been applied to it. Important: if the gradient is a
211  sparse tensor, the variable constraint is not supported.
212
213  ### Thread Compatibility
214
215  The entire optimizer is currently thread-compatible, but not thread-safe. The
216  user needs to perform synchronization if necessary.
217
218  ### Slots
219
220  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
221  additional variables associated with the variables to train.  These are
222  called *Slots*.  Slots have names and you can ask the optimizer for the
223  names of the slots that it uses.  Once you have a slot name you can ask the
224  optimizer for the variable it created to hold the slot value.
225
226  This can be useful if you want to log or debug a training algorithm, report
227  stats about the slots, etc.
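
  For example (a sketch; the slot names depend on the optimizer, e.g. `Adam`
  uses `m` and `v`):

  ```python
  var = tf.Variable(1.0)
  opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  opt.minimize(lambda: (var - 2.0) ** 2, var_list=[var])
  print(opt.get_slot_names())      # ['m', 'v'] for Adam
  m_slot = opt.get_slot(var, 'm')  # first-moment slot variable created for var
  ```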
228
229  ### Hyperparameters
230
231  These are arguments passed to the optimizer subclass constructor
232  (the `__init__` method), and then passed to `self._set_hyper()`.
233  They can be either regular Python values (like 1.0), tensors, or
234  callables. If they are callable, the callable will be called during
235  `apply_gradients()` to get the value for the hyperparameter.
236
237  Hyperparameters can be overwritten through user code:
238
239  Example:
240
241  ```python
242  # Create an optimizer with the desired parameters.
243  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
244  # `loss` is a callable that takes no argument and returns the value
245  # to minimize.
246  loss = lambda: 3 * var1 + 2 * var2
247  # In eager mode, simply call minimize to update the list of variables.
248  opt.minimize(loss, var_list=[var1, var2])
249  # update learning rate
250  opt.learning_rate = 0.05
251  opt.minimize(loss, var_list=[var1, var2])
252  ```
253
254  ### Callable learning rate
255
256  The optimizer accepts a callable learning rate in two ways. The first way is
257  through a built-in or custom
258  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
259  called on each iteration with `schedule(iteration)`, where `iteration` is a
260  `tf.Variable` owned by the optimizer.
261
262  Example:
263
264  >>> var = tf.Variable(np.random.random(size=(1,)))
265  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
266  ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
267  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
268  >>> loss = lambda: 3 * var
269  >>> opt.minimize(loss, var_list=[var])
270  <tf.Variable...
271
272  The second way is through a callable function that
273  does not accept any arguments.
274
275  Example:
276
277  >>> var = tf.Variable(np.random.random(size=(1,)))
278  >>> def lr_callable():
279  ...   return .1
280  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
281  >>> loss = lambda: 3 * var
282  >>> opt.minimize(loss, var_list=[var])
283  <tf.Variable...
284
285  ### Creating a custom optimizer
286
287  If you intend to create your own optimization algorithm, simply inherit from
288  this class and override the following methods (see the sketch after this list):
289
290    - `_resource_apply_dense` (update variable given gradient tensor is a dense
291      `tf.Tensor`)
292    - `_resource_apply_sparse` (update variable given gradient tensor is a
293      sparse `tf.IndexedSlices`. The most common way for this to happen
294      is if you are taking the gradient through a `tf.gather`.)
295    - `_create_slots`
296      (if your optimizer algorithm requires additional variables)
297    - `get_config`
298      (serialization of the optimizer; include all hyperparameters)
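
  The following is a minimal sketch of such a subclass (illustrative only;
  `MyOptimizer` is not part of this module and error checking is omitted):

  ```python
  class MyOptimizer(tf.keras.optimizers.Optimizer):
    # Plain gradient descent written against the hooks listed above.

    def __init__(self, learning_rate=0.01, name="MyOptimizer", **kwargs):
      super(MyOptimizer, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def _resource_apply_dense(self, grad, var, apply_state=None):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return var.assign_sub(lr_t * grad, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return self._resource_scatter_add(var, indices, -lr_t * grad)

    def get_config(self):
      config = super(MyOptimizer, self).get_config()
      config.update(
          {"learning_rate": self._serialize_hyperparameter("learning_rate")})
      return config
  ```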
299  """
300
301  # Subclasses should set this to True unless they override `apply_gradients`
302  # with a version that does not have the `experimental_aggregate_gradients`
303  # argument.  Older versions of Keras did not have this argument so custom
304  # optimizers may have overridden `apply_gradients` without the
305  # `experimental_aggregate_gradients` argument. Keras only passes
306  # `experimental_aggregate_gradients` if this attribute is True.
307  # Note: This attribute will likely be removed in an upcoming release.
308  _HAS_AGGREGATE_GRAD = False
309
310  def __init__(self,
311               name,
312               gradient_aggregator=None,
313               gradient_transformers=None,
314               **kwargs):
315    """Create a new Optimizer.
316
317    This must be called by the constructors of subclasses.
318    Note that Optimizer instances should not bind to a single graph,
319    and so shouldn't keep Tensors as member variables. Generally
320    you should be able to use the `_set_hyper()`/`_get_hyper()`
321    facility instead.
322
323    This class is stateful and thread-compatible.
324
325    Example of custom gradient transformations:
326
327    ```python
328    def my_gradient_transformer(grads_and_vars):
329      # Simple example, double the gradients.
330      return [(2. * g, v) for g, v in grads_and_vars]
331
332    optimizer = tf.keras.optimizers.SGD(
333        1e-3, gradient_transformers=[my_gradient_transformer])
334    ```
335
336    Args:
337      name: String. The name to use for momentum accumulator weights created
338        by the optimizer.
339      gradient_aggregator: The function to use to aggregate gradients across
340        devices (when using `tf.distribute.Strategy`). If `None`, defaults to
341        summing the gradients across devices. The function should accept and
342        return a list of `(gradient, variable)` tuples.
343      gradient_transformers: Optional. List of functions to use to transform
344        gradients before applying updates to Variables. The functions are
345        applied after `gradient_aggregator`. The functions should accept and
346        return a list of `(gradient, variable)` tuples.
347      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
348        `clipnorm`, `global_clipnorm`, and (deprecated) `lr` and `decay`.
349        If `clipvalue` (float) is set, the gradient of each weight
350        is clipped to be no higher than this value.
351        If `clipnorm` (float) is set, the gradient of each weight
352        is individually clipped so that its norm is no higher than this value.
353        If `global_clipnorm` (float) is set, the gradient of all weights is
354        clipped so that their global norm is no higher than this value (see sketch below).
355
356    Raises:
357      ValueError: in case of any invalid argument.
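
    A sketch of gradient clipping via these keyword arguments (`SGD` is shown,
    but any built-in subclass accepts them):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, clipvalue=0.5)
    ```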
358    """
359    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
360    for k in kwargs:
361      if k not in allowed_kwargs:
362        raise TypeError("Unexpected keyword argument "
363                        "passed to optimizer: " + str(k))
364      # checks that all keyword arguments are non-negative.
365      if kwargs[k] is not None and kwargs[k] < 0:
366        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
367      if k == "lr":
368        warnings.warn(
369            "The `lr` argument is deprecated, use `learning_rate` instead.")
370
371    self._use_locking = True
372    self._init_set_name(name)
373    self._hyper = {}
374    # dict: {variable name : {slot name : variable}}
375    self._slots = {}
376    self._slot_names = []
377    self._weights = []
378    self._iterations = None
379
380    # For implementing Trackable. Stores information about how to restore
381    # slot variables which have not yet been created
382    # (trackable._CheckpointPosition objects).
383    #  {slot_name :
384    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
385    #   ... }
386    self._deferred_slot_restorations = {}
387
388    decay = kwargs.pop("decay", 0.0)
389    if decay < 0.:
390      raise ValueError("decay cannot be less than 0: {}".format(decay))
391    self._initial_decay = decay
392
393    self._hypers_created = False
394    # Store the distribution strategy object if the optimizer is created inside
395    # strategy scope, so it could be used to create variables later.
396    if distribute_ctx.has_strategy():
397      self._distribution_strategy = distribute_ctx.get_strategy()
398    else:
399      self._distribution_strategy = None
400
401    # Configure gradient transformations.
402    if gradient_aggregator is None:
403      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
404    self.gradient_aggregator = gradient_aggregator
405    if gradient_transformers is None:
406      gradient_transformers = []
407    self.gradient_transformers = gradient_transformers
408    self.clipnorm = kwargs.pop("clipnorm", None)
409    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
410    if self.clipnorm is not None and self.global_clipnorm is not None:
411      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
412                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
413                           self.clipnorm, self.global_clipnorm))
414    self.clipvalue = kwargs.pop("clipvalue", None)
415
416  @property
417  def clipnorm(self):
418    """`float` or `None`. If set, clips gradients to a maximum norm."""
419    return self._clipnorm
420
421  @property
422  def global_clipnorm(self):
423    """`float` or `None`. If set, clips gradients to a maximum global norm."""
424    return self._global_clipnorm
425
426  @clipnorm.setter
427  def clipnorm(self, val):
428    if val is not None and self.gradient_transformers:
429      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
430                       "is set. Instead, use the `gradient_transformers` to "
431                       "specify clipping and other transformations.")
432    self._clipnorm = val
433    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
434        self._clipnorm)
435
436  @global_clipnorm.setter
437  def global_clipnorm(self, val):
438    if val is not None and self.gradient_transformers:
439      raise ValueError("`global_clipnorm` cannot be set when `gradient_transformers` "
440                       "is set. Instead, use the `gradient_transformers` to "
441                       "specify clipping and other transformations.")
442    self._global_clipnorm = val
443    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
444        self._global_clipnorm)
445
446  @property
447  def clipvalue(self):
448    """`float` or `None`. If set, clips gradients to a maximum value."""
449    return self._clipvalue
450
451  @clipvalue.setter
452  def clipvalue(self, val):
453    if val is not None and self.gradient_transformers:
454      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
455                       "is set. Instead, use the `gradient_transformers` to "
456                       "specify clipping and other transformations.")
457    self._clipvalue = val
458    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
459        self._clipvalue)
460
461  def _transform_loss(self, loss):
462    """Called in `.minimize` to transform loss before computing gradients."""
463    return loss
464
465  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
466    """Called in `minimize` to compute gradients from loss."""
467    grads = tape.gradient(loss, var_list, grad_loss)
468    return list(zip(grads, var_list))
469
470  def _transform_unaggregated_gradients(self, grads_and_vars):
471    """Called in `apply_gradients` before gradient aggregation."""
472    return grads_and_vars
473
474  def _aggregate_gradients(self, grads_and_vars):
475    """Called in `apply_gradients` to aggregate gradients across devices.
476
477    Note that user subclasses may override this, so the interface should not be
478    changed.
479
480    Args:
481      grads_and_vars: List of (gradient, variable) pairs.
482
483    Returns:
484      A list of (aggregated_gradient, variable) pairs. By default, this calls
485      `self.gradient_aggregator`.
486    """
487    return self.gradient_aggregator(grads_and_vars)
488
489  def _transform_gradients(self, grads_and_vars):
490    """Called in `apply_gradients` after aggregation."""
491    if self._clipvalue is not None:
492      grads_and_vars = self._clipvalue_fn(grads_and_vars)
493    if self._clipnorm is not None:
494      grads_and_vars = self._clipnorm_fn(grads_and_vars)
495    if self._global_clipnorm is not None:
496      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
497
498    for fn in self.gradient_transformers:
499      grads_and_vars = fn(grads_and_vars)
500    return grads_and_vars
501
502  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
503    """Minimize `loss` by updating `var_list`.
504
505    This method simply computes gradients using `tf.GradientTape` and calls
506    `apply_gradients()`. If you want to process the gradients before applying
507    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
508    of using this function.
509
510    Args:
511      loss: `Tensor` or callable. If a callable, `loss` should take no arguments
512        and return the value to minimize. If a `Tensor`, the `tape` argument
513        must be passed.
514      var_list: list or tuple of `Variable` objects to update to minimize
515        `loss`, or a callable returning the list or tuple of `Variable` objects.
516        Use a callable when the variable list would otherwise be incomplete
517        before `minimize` is called, since the variables are created the first
518        time `loss` is called.
519      grad_loss: (Optional). A `Tensor` holding the gradient computed for
520        `loss`.
521      name: (Optional) str. Name for the returned operation.
522      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
523        the tape that computed the `loss` must be provided (see the sketch below).
524
525    Returns:
526      An `Operation` that updates the variables in `var_list`. The `iterations`
527      will be automatically increased by 1.
528
529    Raises:
530      ValueError: If some of the variables are not `Variable` objects.
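
    A sketch of passing a `Tensor` loss together with the tape that computed it
    (`model`, `loss_fn`, `x` and `y` are placeholders):

    ```python
    with tf.GradientTape() as tape:
      loss = loss_fn(model(x), y)
    opt.minimize(loss, var_list=model.trainable_variables, tape=tape)
    ```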
531
532    """
533    grads_and_vars = self._compute_gradients(
534        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
535    return self.apply_gradients(grads_and_vars, name=name)
536
537  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
538    """Compute gradients of `loss` for the variables in `var_list`.
539
540    This is the first part of `minimize()`.  It returns a list
541    of (gradient, variable) pairs where "gradient" is the gradient
542    for "variable".  Note that "gradient" can be a `Tensor`, an
543    `IndexedSlices`, or `None` if there is no gradient for the
544    given variable.
545
546    Args:
547      loss: `Tensor` or callable. If a callable, `loss` should take no
548        arguments and return the value to minimize. If a `Tensor`, the `tape`
549        argument must be passed.
550      var_list: list or tuple of `Variable` objects to update to minimize
551        `loss`, or a callable returning the list or tuple of `Variable` objects.
552        Use a callable when the variable list would otherwise be incomplete
553        before `minimize` is called and the variables are created the first time
554        `loss` is called.
555      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
556      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
557        the tape that computed the `loss` must be provided.
558
559    Returns:
560      A list of (gradient, variable) pairs. Variable is always present, but
561      gradient can be `None`.
562
563    Raises:
564      TypeError: If `var_list` contains anything other than `Variable` objects.
565      ValueError: If some arguments are invalid, or `var_list` is None.
566    """
567    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
568    if not callable(loss) and tape is None:
569      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
570    tape = tape if tape is not None else backprop.GradientTape()
571
572    if callable(loss):
573      with tape:
574        if not callable(var_list):
575          tape.watch(var_list)
576        loss = loss()
577        if callable(var_list):
578          var_list = var_list()
579
580    with tape:
581      loss = self._transform_loss(loss)
582
583    var_list = nest.flatten(var_list)
584    with ops.name_scope_v2(self._name + "/gradients"):
585      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
586
587    self._assert_valid_dtypes([
588        v for g, v in grads_and_vars
589        if g is not None and v.dtype != dtypes.resource
590    ])
591
592    return grads_and_vars
593
594  def apply_gradients(self,
595                      grads_and_vars,
596                      name=None,
597                      experimental_aggregate_gradients=True):
598    """Apply gradients to variables.
599
600    This is the second part of `minimize()`. It returns an `Operation` that
601    applies gradients.
602
603    The method sums gradients from all replicas in the presence of
604    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
605    passing `experimental_aggregate_gradients=False`.
606
607    Example:
608
609    ```python
610    grads = tape.gradient(loss, vars)
611    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
612    # Processing aggregated gradients.
613    optimizer.apply_gradients(zip(grads, vars),
614        experimental_aggregate_gradients=False)
615
616    ```
617
618    Args:
619      grads_and_vars: List of (gradient, variable) pairs.
620      name: Optional name for the returned operation. Default to the name passed
621        to the `Optimizer` constructor.
622      experimental_aggregate_gradients: Whether to sum gradients from different
623        replicas in the presense of `tf.distribute.Strategy`. If False, it's
624        user responsibility to aggregate the gradients. Default to True.
625
626    Returns:
627      An `Operation` that applies the specified gradients. The `iterations`
628      will be automatically increased by 1.
629
630    Raises:
631      TypeError: If `grads_and_vars` is malformed.
632      ValueError: If none of the variables have gradients.
633      RuntimeError: If called in a cross-replica context.
634    """
635    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
636    var_list = [v for (_, v) in grads_and_vars]
637
638    with ops.name_scope_v2(self._name):
639      # Create iteration if necessary.
640      with ops.init_scope():
641        self._create_all_weights(var_list)
642
643      if not grads_and_vars:
644        # Distribution strategy does not support reducing an empty list of
645        # gradients
646        return control_flow_ops.no_op()
647
648      if distribute_ctx.in_cross_replica_context():
649        raise RuntimeError(
650            "`apply_gradients()` cannot be called in cross-replica context. "
651            "Use `tf.distribute.Strategy.run` to enter replica "
652            "context.")
653
654      strategy = distribute_ctx.get_strategy()
655      if (not experimental_aggregate_gradients and strategy and
656          isinstance(strategy,
657                     (parameter_server_strategy.ParameterServerStrategyV1,
658                      parameter_server_strategy_v2.ParameterServerStrategyV2,
659                      central_storage_strategy.CentralStorageStrategy,
660                      central_storage_strategy.CentralStorageStrategyV1))):
661        raise NotImplementedError(
662            "`experimental_aggregate_gradients=False` is not supported for "
663            "ParameterServerStrategy and CentralStorageStrategy")
664
665      apply_state = self._prepare(var_list)
666      if experimental_aggregate_gradients:
667        grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
668        grads_and_vars = self._aggregate_gradients(grads_and_vars)
669      grads_and_vars = self._transform_gradients(grads_and_vars)
670
671      if optimizer_utils.strategy_supports_no_merge_call():
672        return self._distributed_apply(strategy, grads_and_vars, name,
673                                       apply_state)
674      else:
675        return distribute_ctx.get_replica_context().merge_call(
676            functools.partial(self._distributed_apply, apply_state=apply_state),
677            args=(grads_and_vars,),
678            kwargs={
679                "name": name,
680            })
681
682  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
683    """`apply_gradients` using a `DistributionStrategy`."""
684
685    def apply_grad_to_update_var(var, grad):
686      """Apply gradient to variable."""
687      if isinstance(var, ops.Tensor):
688        raise NotImplementedError("Trying to update a Tensor: {}".format(var))
689
690      apply_kwargs = {}
691      if isinstance(grad, indexed_slices.IndexedSlices):
692        if var.constraint is not None:
693          raise RuntimeError(
694              "Cannot use a constraint function on a sparse variable.")
695        if "apply_state" in self._sparse_apply_args:
696          apply_kwargs["apply_state"] = apply_state
697        return self._resource_apply_sparse_duplicate_indices(
698            grad.values, var, grad.indices, **apply_kwargs)
699
700      if "apply_state" in self._dense_apply_args:
701        apply_kwargs["apply_state"] = apply_state
702      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
703      if var.constraint is not None:
704        with ops.control_dependencies([update_op]):
705          return var.assign(var.constraint(var))
706      else:
707        return update_op
708
709    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
710    update_ops = []
711    with name_scope_only_in_function_or_graph(name or self._name):
712      for grad, var in grads_and_vars:
713        # Colocate the update with variables to avoid unnecessary communication
714        # delays. See b/136304694.
715        with distribution.extended.colocate_vars_with(var):
716          with name_scope_only_in_function_or_graph(
717              "update" if eagerly_outside_functions else "update_" +
718              var.op.name):
719            update_op = distribution.extended.update(
720                var, apply_grad_to_update_var, args=(grad,), group=False)
721            if distribute_ctx.in_cross_replica_context():
722              # In cross-replica context, extended.update returns a list of
723              # update ops from all replicas (group=False).
724              update_ops.extend(update_op)
725            else:
726              # In replica context, extended.update returns the single update op
727              # of current replica.
728              update_ops.append(update_op)
729
730      any_symbolic = any(isinstance(i, ops.Operation) or
731                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
732      if not context.executing_eagerly() or any_symbolic:
733        # If the current context is graph mode or any of the update ops are
734        # symbolic then the step update should be carried out under a graph
735        # context. (eager updates execute immediately)
736        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
737          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
738            return self._iterations.assign_add(1, read_value=False)
739
740      return self._iterations.assign_add(1)
741
742  def get_gradients(self, loss, params):
743    """Returns gradients of `loss` with respect to `params`.
744
745    Should be used only in legacy v1 graph mode.
746
747    Args:
748      loss: Loss tensor.
749      params: List of variables.
750
751    Returns:
752      List of gradient tensors.
753
754    Raises:
755      ValueError: In case any gradient cannot be computed (e.g. if gradient
756        function not implemented).
757    """
758    params = nest.flatten(params)
759    with backend.get_graph().as_default(), backend.name_scope(self._name +
760                                                              "/gradients"):
761      grads = gradients.gradients(loss, params)
762      for grad, param in zip(grads, params):
763        if grad is None:
764          raise ValueError("Variable {} has `None` for gradient. "
765                           "Please make sure that all of your ops have a "
766                           "gradient defined (i.e. are differentiable). "
767                           "Common ops without gradient: "
768                           "K.argmax, K.round, K.eval.".format(param))
769    return grads
770
771  def get_updates(self, loss, params):
772    grads = self.get_gradients(loss, params)
773    grads_and_vars = list(zip(grads, params))
774    self._assert_valid_dtypes([
775        v for g, v in grads_and_vars
776        if g is not None and v.dtype != dtypes.resource
777    ])
778    return [self.apply_gradients(grads_and_vars)]
779
780  def _set_hyper(self, name, value):
781    """Sets hyper `name` to `value`, which can be callable, tensor, or numeric."""
782    if isinstance(value, trackable.Trackable):
783      self._track_trackable(value, name, overwrite=True)
784    if name not in self._hyper:
785      self._hyper[name] = value
786    else:
787      prev_value = self._hyper[name]
788      if (callable(prev_value)
789          or isinstance(prev_value,
790                        (ops.Tensor, int, float,
791                         learning_rate_schedule.LearningRateSchedule))
792          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
793        self._hyper[name] = value
794      else:
795        backend.set_value(self._hyper[name], value)
796
797  def _get_hyper(self, name, dtype=None):
798    if not self._hypers_created:
799      self._create_hypers()
800    value = self._hyper[name]
801    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
802      return value
803    if callable(value):
804      value = value()
805    if dtype:
806      return math_ops.cast(value, dtype)
807    else:
808      return value
809
810  def _create_slots(self, var_list):
811    pass
812
813  def _create_all_weights(self, var_list):
814    """Creates all weights, including iterations, hyperparameters and slot vars.
815
816    This will add newly created variables to `optimizer.weights`.
817
818    New variables are only created when this method is called the first time, or
819    when called with different variables in the var_list.
820
821    Args:
822      var_list: list or tuple of `Variable` objects that will be minimized
823        using this optimizer.
824    """
825
826    _ = self.iterations
827    self._create_hypers()
828    self._create_slots(var_list)
829
830  def __getattribute__(self, name):
831    """Overridden to support hyperparameter access."""
832    try:
833      return super(OptimizerV2, self).__getattribute__(name)
834    except AttributeError as e:
835      # Needed to avoid infinite recursion with __setattr__.
836      if name == "_hyper":
837        raise e
838      # Backwards compatibility with Keras optimizers.
839      if name == "lr":
840        name = "learning_rate"
841      if name in self._hyper:
842        return self._get_hyper(name)
843      raise e
844
845  def __dir__(self):
846    result = set(super(OptimizerV2, self).__dir__())
847    if "_hyper" in result:
848      result |= self._hyper.keys()
849      if "learning_rate" in self._hyper.keys():
850        result.add("lr")
851    return list(result)
852
853  def __setattr__(self, name, value):
854    """Override setattr to support dynamic hyperparameter setting."""
855    # Backwards compatibility with Keras optimizers.
856    if name == "lr":
857      name = "learning_rate"
858    if hasattr(self, "_hyper") and name in self._hyper:
859      self._set_hyper(name, value)
860    else:
861      super(OptimizerV2, self).__setattr__(name, value)
862
863  def get_slot_names(self):
864    """A list of names for this optimizer's slots."""
865    return self._slot_names
866
867  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
868    """Add a new slot variable for `var`.
869
870    A slot variable is an additional variable associated with `var` to train.
871    It is allocated and managed by optimizers, e.g. `Adam`.
872
873    Args:
874      var: a `Variable` object.
875      slot_name: name of the slot variable.
876      initializer: initializer of the slot variable.
877      shape: (Optional) shape of the slot variable. If not set, it will default
878        to the shape of `var`.
879
880    Returns:
881      A slot variable.
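
    A typical use, sketched (inside a subclass's `_create_slots`):

    ```python
    def _create_slots(self, var_list):
      for var in var_list:
        self.add_slot(var, "momentum")
    ```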
882    """
883    if slot_name not in self._slot_names:
884      self._slot_names.append(slot_name)
885    var_key = _var_key(var)
886    slot_dict = self._slots.setdefault(var_key, {})
887    weight = slot_dict.get(slot_name, None)
888    if weight is None:
889      if isinstance(initializer, str) or callable(initializer):
890        initializer = initializers.get(initializer)
891        if isinstance(
892            initializer,
893            trackable.CheckpointInitialValueCallable) or (shape is not None):
894          slot_shape = shape
895        else:
896          slot_shape = var.shape
897        initial_value = functools.partial(
898            initializer, shape=slot_shape, dtype=var.dtype)
899      else:
900        initial_value = initializer
901
902      with self._distribution_strategy_scope():
903        strategy = distribute_ctx.get_strategy()
904        if not strategy.extended.variable_created_in_scope(var):
905          raise ValueError(
906              "Trying to create optimizer slot variable under the scope for "
907              "tf.distribute.Strategy ({}), which is different from the scope "
908              "used for the original variable ({}). Make sure the slot "
909              "variables are created under the same strategy scope. This may "
910              "happen if you're restoring from a checkpoint outside the scope"
911              .format(strategy, var))
912
913        with strategy.extended.colocate_vars_with(var):
914          weight = tf_variables.Variable(
915              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
916              dtype=var.dtype,
917              trainable=False,
918              initial_value=initial_value)
919      backend.track_variable(weight)
920      slot_dict[slot_name] = weight
921      self._restore_slot_variable(
922          slot_name=slot_name, variable=var,
923          slot_variable=weight)
924      self._weights.append(weight)
925    return weight
926
927  def get_slot(self, var, slot_name):
928    var_key = _var_key(var)
929    slot_dict = self._slots[var_key]
930    return slot_dict[slot_name]
931
932  def _prepare(self, var_list):
933    keys = set()
934    for var in var_list:
935      if isinstance(var, ds_values.DistributedValues):
936        var_devices = var._devices   # pylint: disable=protected-access
937      else:
938        var_devices = [var.device]
939      var_dtype = var.dtype.base_dtype
940      for var_device in var_devices:
941        keys.add((var_device, var_dtype))
942
943    apply_state = {}
944    for var_device, var_dtype in keys:
945      apply_state[(var_device, var_dtype)] = {}
946      with ops.device(var_device):
947        self._prepare_local(var_device, var_dtype, apply_state)
948
949    return apply_state
950
951  def _prepare_local(self, var_device, var_dtype, apply_state):
952    if "learning_rate" in self._hyper:
953      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
954      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
955
956  def _fallback_apply_state(self, var_device, var_dtype):
957    """Compatibility for subclasses that don't pass apply_state through."""
958    apply_state = {(var_device, var_dtype): {}}
959    self._prepare_local(var_device, var_dtype, apply_state)
960    return apply_state[(var_device, var_dtype)]
961
962  def _create_hypers(self):
963    if self._hypers_created:
964      return
965    with self._distribution_strategy_scope():
966      # Iterate hyper values deterministically.
967      for name, value in sorted(self._hyper.items()):
968        if isinstance(value,
969                      (ops.Tensor, tf_variables.Variable)) or callable(value):
970          # The check for `callable` covers the usage when `value` is a
971          # `LearningRateSchedule`, in which case it does not need to create a
972          # variable.
973          continue
974        else:
975          self._hyper[name] = self.add_weight(
976              name,
977              shape=[],
978              trainable=False,
979              initializer=value,
980              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
981    self._hypers_created = True
982
983  @property
984  def iterations(self):
985    """Variable. The number of training steps this Optimizer has run."""
986    if self._iterations is None:
987      with self._distribution_strategy_scope():
988        self._iterations = self.add_weight(
989            "iter",
990            shape=[],
991            dtype=dtypes.int64,
992            trainable=False,
993            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
994      self._weights.append(self._iterations)
995    return self._iterations
996
997  @iterations.setter
998  def iterations(self, variable):
999    if self._iterations is not None:
1000      raise RuntimeError("Cannot set `iterations` to a new Variable after "
1001                         "the Optimizer weights have been created")
1002    self._iterations = variable
1003    self._weights.append(self._iterations)
1004
1005  def _decayed_lr(self, var_dtype):
1006    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
1007    lr_t = self._get_hyper("learning_rate", var_dtype)
1008    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
1009      local_step = math_ops.cast(self.iterations, var_dtype)
1010      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
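    # If a `decay` kwarg was passed to the constructor, apply inverse-time
    # decay: lr_t <- lr_t / (1 + decay * iterations).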
1011    if self._initial_decay > 0.:
1012      local_step = math_ops.cast(self.iterations, var_dtype)
1013      decay_t = math_ops.cast(self._initial_decay, var_dtype)
1014      lr_t = lr_t / (1. + decay_t * local_step)
1015    return lr_t
1016
1017  @abc.abstractmethod
1018  def get_config(self):
1019    """Returns the config of the optimizer.
1020
1021    An optimizer config is a Python dictionary (serializable)
1022    containing the configuration of an optimizer.
1023    The same optimizer can be reinstantiated later
1024    (without any saved state) from this configuration.
1025
1026    Returns:
1027        Python dictionary.
1028    """
1029    config = {"name": self._name}
1030    if self.clipnorm is not None:
1031      config["clipnorm"] = self.clipnorm
1032    if self.clipvalue is not None:
1033      config["clipvalue"] = self.clipvalue
1034    if self.global_clipnorm is not None:
1035      config["global_clipnorm"] = self.global_clipnorm
1036    return config
1037
1038  @classmethod
1039  def from_config(cls, config, custom_objects=None):
1040    """Creates an optimizer from its config.
1041
1042    This method is the reverse of `get_config`,
1043    capable of instantiating the same optimizer from the config
1044    dictionary.
1045
1046    Args:
1047        config: A Python dictionary, typically the output of get_config.
1048        custom_objects: A Python dictionary mapping names to additional Python
1049          objects used to create this optimizer, such as a function used for a
1050          hyperparameter.
1051
1052    Returns:
1053        An optimizer instance.
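
    A sketch of the round trip with a built-in optimizer:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    config = opt.get_config()
    new_opt = tf.keras.optimizers.SGD.from_config(config)
    ```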
1054    """
1055    if "lr" in config:
1056      config["learning_rate"] = config.pop("lr")
1057    if "learning_rate" in config:
1058      if isinstance(config["learning_rate"], dict):
1059        config["learning_rate"] = learning_rate_schedule.deserialize(
1060            config["learning_rate"], custom_objects=custom_objects)
1061    return cls(**config)
1062
1063  def _serialize_hyperparameter(self, hyperparameter_name):
1064    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
1065    value = self._hyper[hyperparameter_name]
1066    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
1067      return learning_rate_schedule.serialize(value)
1068    if callable(value):
1069      return value()
1070    if tensor_util.is_tf_type(value):
1071      return backend.get_value(value)
1072    return value
1073
1074  def variables(self):
1075    """Returns variables of this Optimizer based on the order created."""
1076    return self._weights
1077
1078  @property
1079  def weights(self):
1080    """Returns variables of this Optimizer based on the order created."""
1081    return self._weights
1082
1083  def get_weights(self):
1084    """Returns the current weights of the optimizer.
1085
1086    The weights of an optimizer are its state (i.e., variables).
1087    This function returns the weight values associated with this
1088    optimizer as a list of Numpy arrays. The first value is always the
1089    iterations count of the optimizer, followed by the optimizer's state
1090    variables in the order they were created. The returned list can in turn
1091    be used to load state into similarly parameterized optimizers.
1092
1093    For example, the RMSprop optimizer for this simple model returns a list of
1094    three values: the iteration count, followed by the root-mean-square value
1095    of the kernel and bias of the single Dense layer:
1096
1097    >>> opt = tf.keras.optimizers.RMSprop()
1098    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1099    >>> m.compile(opt, loss='mse')
1100    >>> data = np.arange(100).reshape(5, 20)
1101    >>> labels = np.zeros(5)
1102    >>> results = m.fit(data, labels)  # Training.
1103    >>> len(opt.get_weights())
1104    3
1105
1106    Returns:
1107        Weights values as a list of numpy arrays.
1108    """
1109    params = self.weights
1110    return backend.batch_get_value(params)
1111
1112  # TODO(tanzheny): Maybe share this logic with base_layer.
1113  def set_weights(self, weights):
1114    """Set the weights of the optimizer.
1115
1116    The weights of an optimizer are its state (i.e., variables).
1117    This function takes the weight values associated with this
1118    optimizer as a list of Numpy arrays. The first value is always the
1119    iterations count of the optimizer, followed by the optimizer's state
1120    variables in the order they are created. The passed values are used to set
1121    the new state of the optimizer.
1122
1123    For example, the RMSprop optimizer for this simple model takes a list of
1124    three values: the iteration count, followed by the root-mean-square value
1125    of the kernel and bias of the single Dense layer:
1126
1127    >>> opt = tf.keras.optimizers.RMSprop()
1128    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1129    >>> m.compile(opt, loss='mse')
1130    >>> data = np.arange(100).reshape(5, 20)
1131    >>> labels = np.zeros(5)
1132    >>> results = m.fit(data, labels)  # Training.
1133    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
1134    >>> opt.set_weights(new_weights)
1135    >>> opt.iterations
1136    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>
1137
1138    Args:
1139        weights: weight values as a list of numpy arrays.
1140    """
1141    params = self.weights
1142    if len(params) != len(weights):
1143      raise ValueError(
1144          "You called `set_weights(weights)` on optimizer " + self._name +
1145          " with a weight list of length " + str(len(weights)) +
1146          ", but the optimizer was expecting " + str(len(params)) +
1147          " weights. Provided weights: " + str(weights)[:50] + "...")
1148    if not params:
1149      return
1150    weight_value_tuples = []
1151    param_values = backend.batch_get_value(params)
1152    for pv, p, w in zip(param_values, params, weights):
1153      if pv.shape != w.shape:
1154        raise ValueError("Optimizer weight shape " + str(pv.shape) +
1155                         " not compatible with "
1156                         "provided weight shape " + str(w.shape))
1157      weight_value_tuples.append((p, w))
1158    backend.batch_set_value(weight_value_tuples)
1159
1160  def add_weight(self,
1161                 name,
1162                 shape,
1163                 dtype=None,
1164                 initializer="zeros",
1165                 trainable=None,
1166                 synchronization=tf_variables.VariableSynchronization.AUTO,
1167                 aggregation=tf_variables.VariableAggregation.NONE):
1168
1169    if dtype is None:
1170      dtype = dtypes.float32
1171    if isinstance(initializer, str) or callable(initializer):
1172      initializer = initializers.get(initializer)
1173
1174    if synchronization == tf_variables.VariableSynchronization.ON_READ:
1175      if trainable:
1176        raise ValueError(
1177            "Synchronization value can be set to "
1178            "VariableSynchronization.ON_READ only for non-trainable variables. "
1179            "You have specified trainable=True and "
1180            "synchronization=VariableSynchronization.ON_READ.")
1181      else:
1182        # Set trainable to be false when variable is to be synced on read.
1183        trainable = False
1184    elif trainable is None:
1185      trainable = True
1186
1187    variable = self._add_variable_with_custom_getter(
1188        name=name,
1189        shape=shape,
1190        getter=base_layer_utils.make_variable,
1191        overwrite=True,
1192        initializer=initializer,
1193        dtype=dtype,
1194        trainable=trainable,
1195        use_resource=True,
1196        synchronization=synchronization,
1197        aggregation=aggregation)
1198    backend.track_variable(variable)
1199
1200    return variable
1201
1202  def _init_set_name(self, name, zero_based=True):
1203    if not name:
1204      self._name = backend.unique_object_name(
1205          generic_utils.to_snake_case(self.__class__.__name__),
1206          zero_based=zero_based)
1207    else:
1208      self._name = name
1209
1210  def _assert_valid_dtypes(self, tensors):
1211    """Asserts tensors are all valid types (see `_valid_dtypes`).
1212
1213    Args:
1214      tensors: Tensors to check.
1215
1216    Raises:
1217      ValueError: If any tensor is not a valid type.
1218    """
1219    valid_dtypes = self._valid_dtypes()
1220    for t in tensors:
1221      dtype = t.dtype.base_dtype
1222      if dtype not in valid_dtypes:
1223        raise ValueError("Invalid type %r for %s, expected: %s." %
1224                         (dtype, t.name, [v for v in valid_dtypes]))
1225
1226  def _valid_dtypes(self):
1227    """Valid types for loss, variables and gradients.
1228
1229    Subclasses should override to allow other float types.
1230
1231    Returns:
1232      Valid types for loss, variables and gradients.
1233    """
1234    return _DEFAULT_VALID_DTYPES
1235
1236  def _call_if_callable(self, param):
1237    """Call the function if param is callable."""
1238    return param() if callable(param) else param
1239
1240  def _resource_apply_dense(self, grad, handle, apply_state):
1241    """Add ops to apply dense gradients to the variable `handle`.
1242
1243    Args:
1244      grad: a `Tensor` representing the gradient.
1245      handle: a `Tensor` of dtype `resource` which points to the variable to be
1246        updated.
1247      apply_state: A dict which is used across multiple apply calls.
1248
1249    Returns:
1250      An `Operation` which updates the value of the variable.
1251    """
1252    raise NotImplementedError("Must be implemented in subclasses.")
1253
1254  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
1255                                               **kwargs):
1256    """Add ops to apply sparse gradients to `handle`, with repeated indices.
1257
1258    Optimizers which override this method must deal with repeated indices. See
1259    the docstring of `_apply_sparse_duplicate_indices` for details. By default
1260    the correct behavior, to sum non-unique indices and their associated
1261    gradients, is enforced by first pre-processing `grad` and `indices` and
1262    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
1263    with duplicate indices may instead override this method to avoid the
1264    overhead of summing.
1265
1266    Args:
1267      grad: a `Tensor` representing the gradient for the affected indices.
1268      handle: a `Tensor` of dtype `resource` which points to the variable to be
1269        updated.
1270      indices: a `Tensor` of integral type representing the indices for which
1271        the gradient is nonzero. Indices may be repeated.
1272      **kwargs: May optionally contain `apply_state`.
1273
1274    Returns:
1275      An `Operation` which updates the value of the variable.
1276    """
1277    summed_grad, unique_indices = _deduplicate_indexed_slices(
1278        values=grad, indices=indices)
1279    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
1280                                       **kwargs)
1281
1282  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
1283    """Add ops to apply sparse gradients to the variable `handle`.
1284
1285    Similar to `_apply_sparse`, the `indices` argument to this method has been
1286    de-duplicated. Optimizers which deal correctly with non-unique indices may
1287    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1288    overhead.
1289
1290    Args:
1291      grad: a `Tensor` representing the gradient for the affected indices.
1292      handle: a `Tensor` of dtype `resource` which points to the variable to be
1293        updated.
1294      indices: a `Tensor` of integral type representing the indices for which
1295        the gradient is nonzero. Indices are unique.
1296      apply_state: A dict which is used across multiple apply calls.
1297
1298    Returns:
1299      An `Operation` which updates the value of the variable.
1300    """
1301    raise NotImplementedError("Must be implemented in subclasses.")
1302
1303  def _resource_scatter_add(self, x, i, v):
1304    with ops.control_dependencies([
1305        gen_resource_variable_ops.ResourceScatterAdd(
1306            resource=x.handle, indices=i, updates=v)
1307    ]):
1308      return x.value()
1309
1310  def _resource_scatter_update(self, x, i, v):
1311    with ops.control_dependencies(
1312        [gen_resource_variable_ops.ResourceScatterUpdate(
1313            resource=x.handle, indices=i, updates=v)]):
1314      return x.value()
1315
1316  @property
1317  @layer_utils.cached_per_instance
1318  def _dense_apply_args(self):
1319    return tf_inspect.getfullargspec(self._resource_apply_dense).args
1320
1321  @property
1322  @layer_utils.cached_per_instance
1323  def _sparse_apply_args(self):
1324    return tf_inspect.getfullargspec(self._resource_apply_sparse).args
1325
1326  # ---------------
1327  # For implementing the trackable interface
1328  # ---------------
1329
1330  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1331    """Restore a newly created slot variable's value."""
1332    variable_key = _var_key(variable)
1333    deferred_restorations = self._deferred_slot_restorations.get(
1334        slot_name, {}).pop(variable_key, [])
1335    # Iterate over restores, highest restore UID first to minimize the number
1336    # of assignments.
1337    deferred_restorations.sort(key=lambda position: position.restore_uid,
1338                               reverse=True)
1339    for checkpoint_position in deferred_restorations:
1340      checkpoint_position.restore(slot_variable)
1341
1342  def _create_or_restore_slot_variable(
1343      self, slot_variable_position, slot_name, variable):
1344    """Restore a slot variable's value, possibly creating it.
1345
1346    Called when a variable which has an associated slot variable is created or
1347    restored. When executing eagerly, we create the slot variable with a
1348    restoring initializer.
1349
1350    No new variables are created when graph building. Instead,
1351    _restore_slot_variable catches these after normal creation and adds restore
1352    ops to the graph. This method is nonetheless important when graph building
1353    for the case when a slot variable has already been created but `variable`
1354    has just been added to a dependency graph (causing us to realize that the
1355    slot variable needs to be restored).
1356
1357    Args:
1358      slot_variable_position: A `trackable._CheckpointPosition` object
1359        indicating the slot variable `Trackable` object to be restored.
1360      slot_name: The name of this `Optimizer`'s slot to restore into.
1361      variable: The variable object this slot is being created for.
1362    """
1363    variable_key = _var_key(variable)
1364    slot_dict = self._slots.get(variable_key, {})
1365    slot_variable = slot_dict.get(slot_name, None)
1366    if (slot_variable is None and context.executing_eagerly() and
1367        slot_variable_position.is_simple_variable()
1368        # Defer slot variable creation if there is an active variable creator
1369        # scope. Generally we'd like to eagerly create/restore slot variables
1370        # when possible, but this may mean that scopes intended to catch
1371        # `variable` also catch its eagerly created slot variable
1372        # unintentionally (specifically make_template would add a dependency on
1373        # a slot variable if not for this case). Deferring is mostly harmless
1374        # (aside from double initialization), and makes variable creator scopes
1375        # behave the same way they do when graph building.
1376        #
1377        # One notable case is with distribution strategy, which uses variable
1378        # creator scope but always desires the `variable` and the slot to use
1379        # the same scope, thus we can safely eagerly create/restore slot
1380        # variables.
1381        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
1382             self._distribution_strategy)):
1383      initializer = trackable.CheckpointInitialValueCallable(
1384          checkpoint_position=slot_variable_position)
1385      slot_variable = self.add_slot(
1386          var=variable,
1387          initializer=initializer,
1388          slot_name=slot_name,
1389          shape=slot_variable_position.value_shape())
1390      # Slot variables are not owned by any one object (because we don't want to
1391      # save the slot variable if the optimizer is saved without the non-slot
1392      # variable, or if the non-slot variable is saved without the optimizer;
1393      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
1394      # variable, variable)). So we don't _track_ slot variables anywhere, and
1395      # instead special-case this dependency and otherwise pretend it's a normal
1396      # graph.
1397    if slot_variable is not None:
1398      # If we've either made this slot variable, or if we've pulled out an
1399      # existing slot variable, we should restore it.
1400      slot_variable_position.restore(slot_variable)
1401    else:
1402      # We didn't make the slot variable. Defer restoring until it gets created
1403      # normally. We keep a list rather than the one with the highest restore
1404      # UID in case slot variables have their own dependencies, in which case
1405      # those could differ between restores.
1406      self._deferred_slot_restorations.setdefault(
1407          slot_name, {}).setdefault(variable_key, []).append(
1408              slot_variable_position)
1409
1410  @contextlib.contextmanager
1411  def _distribution_strategy_scope(self):
1412    """Enters the `tf.distribute.Strategy` scope the optimizer was created under, if any."""
1413    if self._distribution_strategy and not distribute_ctx.has_strategy():
1414      with self._distribution_strategy.scope():
1415        yield self._distribution_strategy.scope()
1416    else:
1417      yield
1418
1419
1420def _var_key(var):
1421  """Key for representing a primary variable, for looking up slots.
1422
1423  In graph mode the name is derived from the var shared name.
1424  In eager mode the name is derived from the var unique id.
1425  If distribution strategy exists, get the primary variable first.
1426
1427  Args:
1428    var: the variable.
1429
1430  Returns:
1431    the unique name of the variable.
1432  """
1433
1434  # pylint: disable=protected-access
1435  # Get the distributed variable if it exists.
1436  if hasattr(var, "_distributed_container"):
1437    var = var._distributed_container()
1438  if var._in_graph_mode:
1439    return var._shared_name
1440  return var._unique_id
1441
1442
1443def _get_slot_key_from_var(var, slot_name):
1444  """Get the slot key for the variable: var_name/slot_name."""
1445
1446  name = _var_key(var)
1447  return name + "/" + slot_name
1448
1449
1450class RestoredOptimizer(OptimizerV2):
1451  """A non-functional Optimizer implementation for checkpoint compatibility.
1452
1453  Holds slot variables and hyperparameters when an optimizer is restored from a
1454  SavedModel. These variables may be referenced in functions along with ops
1455  created by the original optimizer, but currently we do not support using the
1456  optimizer object itself (e.g. through `apply_gradients`).
1457  """
1458  # TODO(allenl): Make the restored optimizer functional by tracing its apply
1459  # methods.
1460
1461  def __init__(self):
1462    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
1463    self._hypers_created = True
1464
1465  def get_config(self):
1466    # TODO(allenl): Save and restore the Optimizer's config
1467    raise NotImplementedError(
1468        "Restoring functional Optimizers from SavedModels is not currently "
1469        "supported. Please file a feature request if this limitation bothers "
1470        "you.")
1471


1472revived_types.register_revived_type(
1473    "tf_deprecated_optimizer",
1474    lambda obj: isinstance(obj, OptimizerV2),
1475    versions=[revived_types.VersionedTypeRegistration(
1476        object_factory=lambda proto: RestoredOptimizer(),
1477        version=1,
1478        min_producer_version=1,
1479        min_consumer_version=1,
1480        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
1481    )])
1482