xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the loss scaling optimizer class."""
16
17from tensorflow.python.distribute import collective_all_reduce_strategy
18from tensorflow.python.distribute import distribution_strategy_context
19from tensorflow.python.distribute import mirrored_strategy
20from tensorflow.python.distribute import one_device_strategy
21from tensorflow.python.distribute import tpu_strategy
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import context
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import indexed_slices
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import smart_cond
28from tensorflow.python.keras import backend
29from tensorflow.python.keras import optimizers
30from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
31from tensorflow.python.keras.optimizer_v2 import optimizer_v2
32from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import variable_scope
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import tf_logging
38from tensorflow.python.trackable import base as trackable
39from tensorflow.python.trackable import base_delegate
40from tensorflow.python.training.experimental import loss_scale as loss_scale_module
41from tensorflow.python.training.experimental import mixed_precision
42from tensorflow.python.util import nest
43from tensorflow.python.util.tf_export import keras_export
44
45
46class _UnwrapPreventer(object):
47  """Wrapper that DistributionStrategy will not unwrap.
48
49  Typically, DistributionStrategy will unwrap values when going from a cross-
50  replica context to a replica context via `call_for_each_replica`. This class
51  is a wrapper that DistributionStrategy will not unwrap, so it can be used to
52  prevent it from unwrapping a value.
53
54  TODO(reedwm): Find/implement a better way of preventing values from being
55  unwrapped by DistributionStrategy
56  """
57
58  __slots__ = ['value']
59
60  def __init__(self, value):
61    self.value = value
62
63
def _is_all_finite(grads):
  """Returns a scalar boolean tensor that is True iff all gradients are finite.

  Entries of `grads` that are `None` are ignored.
  """
  per_grad_checks = []
  for grad in grads:
    if grad is None:
      continue
    # Collapse each gradient to a single "all elements finite" boolean.
    per_grad_checks.append(math_ops.reduce_all(math_ops.is_finite(grad)))
  return math_ops.reduce_all(per_grad_checks)
70
71
def _op_in_graph_mode(tensor):
  """Returns `tensor.op` when building a graph, otherwise `tensor` itself.

  Some callers need an op rather than a tensor in graph mode, but eager
  execution has no ops, so the tensor is returned unchanged there.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode; the tensor in eager mode.
  """
  return tensor if context.executing_eagerly() else tensor.op
87
88
def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite; otherwise no-op."""
  value_is_finite = math_ops.is_finite(value)

  def _do_assign():
    return _op_in_graph_mode(var.assign(value))

  return control_flow_ops.cond(value_is_finite, _do_assign,
                               control_flow_ops.no_op)
94
95
class _DynamicLossScaleState(trackable.Trackable):
  """The state of a dynamic loss scale.

  Owns two non-trainable resource variables: 'current_loss_scale' (float32)
  and 'good_steps' (int64, exposed as `counter`). They are stored in
  `self._weights` keyed by `(name, graph_key)`, where `graph_key` is None when
  executing eagerly. Only the entries belonging to the current graph are
  reported as checkpoint dependencies or returned by dependency lookup.
  """

  def __init__(self,
               initial_loss_scale,
               growth_steps,
               multiplier):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: Initial value of the loss scale. Converted to float.
      growth_steps: After this many consecutive steps with finite gradients
        the loss scale is multiplied by `multiplier`. Converted to int.
      multiplier: Factor the loss scale is multiplied by when it grows and
        divided by when nonfinite gradients are seen. Converted to float.
    """
    super(_DynamicLossScaleState, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._growth_steps = int(growth_steps)
    self._multiplier = float(multiplier)

    self._weights = {}
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale. The name is 'good_steps' for
    # backwards compatibility with older checkpoints.
    self._counter = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @staticmethod
  def _current_graph_key():
    """Returns the default graph's key, or None when executing eagerly.

    Used as the second half of the `self._weights` keys so that variables are
    only looked up / checkpointed from the graph they were created in.
    """
    if context.executing_eagerly():
      return None
    return ops.get_default_graph()._graph_key  # pylint: disable=protected-access

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    The variable is recorded in `self._weights` under `(name, graph_key)`.
    An existing entry with the same name in the same graph is overwritten
    (no duplicate check is performed).

    Args:
      name: Variable name; also used as the checkpoint dependency name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    key = (name, self._current_graph_key())
    self._weights[key] = variable
    # Lets a checkpoint restore that was requested for `name` before this
    # variable existed be applied to it now.
    self._handle_deferred_dependencies(name=name, trackable=variable)
    backend.track_variable(variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = self._current_graph_key()
    weights = {}
    # Sort by variable name so the dependencies are listed deterministically.
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(_DynamicLossScaleState,
              self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, self._current_graph_key()), None)

  @property
  def initial_loss_scale(self):
    """The initial loss scale, as a Python float."""
    return self._initial_loss_scale

  @property
  def growth_steps(self):
    """Consecutive finite-gradient steps needed to grow the loss scale (int)."""
    return self._growth_steps

  @property
  def multiplier(self):
    """Growth/shrink factor of the loss scale, as a Python float."""
    return self._multiplier

  @property
  def current_loss_scale(self):
    """Returns the current loss scale as a float32 `tf.Variable`."""
    return self._current_loss_scale

  @property
  def counter(self):
    """Returns the counter as an int64 `tf.Variable`."""
    return self._counter

  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    return ops.convert_to_tensor_v2_with_dispatch(self._current_loss_scale)

  def update(self, grads):
    """Updates the value of the loss scale.

    If every gradient is finite, the counter is incremented; when the
    increment would reach `growth_steps`, the loss scale is instead multiplied
    by `multiplier` (only if the result is finite) and the counter is reset to
    zero. If any gradient is nonfinite, the loss scale is divided by
    `multiplier` (but not below 1) and the counter is reset to zero.

    Args:
      grads: A nested structure of unscaled gradients, each which is an
        all-reduced gradient of the loss with respect to a weight.

    Returns:
      update_op: In graph mode, an op to update the loss scale. In eager mode
        the update runs immediately and the returned value should not be
        relied upon.
      should_apply_gradients: A scalar boolean tensor. If False, the caller
        should skip applying `grads` to the variables this step.
    """
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy(
    ) and distribution_strategy_context.in_cross_replica_context():
      distribution = distribution_strategy_context.get_strategy()
      is_finite_per_replica = distribution.extended.call_for_each_replica(
          _is_all_finite, args=(grads,))
      # Each replica computed the same `is_finite` value, since `grads` is
      # all-reduced across replicas. Arbitrarily take `is_finite` from the first
      # replica.
      is_finite = (
          distribution.experimental_local_results(is_finite_per_replica)[0])
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self.current_loss_scale * self.multiplier
        # The counter is reset even if the grown loss scale overflowed and was
        # therefore not assigned.
        return control_flow_ops.group(
            _assign_if_finite(self.current_loss_scale, new_loss_scale),
            self.counter.assign(0))

      return control_flow_ops.cond(
          self.counter + 1 >= self.growth_steps,
          incr_loss_scale,
          lambda: _op_in_graph_mode(self.counter.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      # Never shrink the loss scale below 1.
      new_loss_scale = math_ops.maximum(
          self.current_loss_scale / self.multiplier, 1)
      return control_flow_ops.group(
          self.counter.assign(0),
          self.current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients
267
268
269# See LossScaleOptimizer docstring for why this is so big
270_DEFAULT_INITIAL_SCALE = 2 ** 15
271_DEFAULT_GROWTH_STEPS = 2000
272
273
274# pylint: disable=g-classes-have-attributes
275@keras_export('keras.mixed_precision.LossScaleOptimizer')
276class LossScaleOptimizer(base_delegate.DelegatingTrackableMixin,
277                         optimizer_v2.OptimizerV2):
278  """An optimizer that applies loss scaling to prevent numeric underflow.
279
280  Loss scaling is a technique to prevent numeric underflow in intermediate
281  gradients when float16 is used. To prevent underflow, the loss is multiplied
282  (or "scaled") by a certain factor called the "loss scale", which causes
283  intermediate gradients to be scaled by the loss scale as well. The final
284  gradients are divided (or "unscaled") by the loss scale to bring them back to
285  their original value.
286
287  `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
288  By default, the loss scale is dynamically updated over time so you do not have
289  to choose the loss scale. The `minimize` method automatically scales the loss,
290  unscales the gradients, and updates the loss scale so all you have to do is
291  wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For
292  example:
293
294  >>> opt = tf.keras.optimizers.SGD(0.25)
295  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
296  >>> var = tf.Variable(1.)
297  >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling and updates the loss scale.
299  >>> opt.minimize(loss_fn, var_list=var)
300  >>> var.numpy()
301  0.5
302
303  If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you
304  must scale the loss and gradients manually. This can be done with the
305  `LossScaleOptimizer.get_scaled_loss` and
306  `LossScaleOptimizer.get_unscaled_gradients` methods. For example:
307
308  >>> with tf.GradientTape() as tape:
309  ...   loss = loss_fn()
310  ...   scaled_loss = opt.get_scaled_loss(loss)
311  >>> scaled_grad = tape.gradient(scaled_loss, var)
312  >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
313  >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
314  >>> var.numpy()
315  0.25
316
317  Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
318  (or both) when using a `tf.GradientTape`, the model will likely converge to a
319  worse quality. Please make sure you call each function exactly once.
320
321  When mixed precision with float16 is used, there is typically no risk of
322  underflow affecting model quality if loss scaling is properly used. See
323  [the mixed precision guide](
324  https://www.tensorflow.org/guide/keras/mixed_precision) for more information
325  on how to use mixed precision.
326
327  Args:
328    inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap.
329    dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
330      True. If True, the loss scale will be dynamically updated over time using
331      an algorithm that keeps the loss scale at approximately its optimal value.
332      If False, a single fixed loss scale is used and `initial_scale` must be
333      specified, which is used as the loss scale. Recommended to keep as True,
334      as choosing a fixed loss scale can be tricky. Currently, there is a small
335      performance overhead to dynamic loss scaling compared to fixed loss
336      scaling.
337    initial_scale: The initial loss scale. If `dynamic` is True, this defaults
338      to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
339      the sole loss scale, as the loss scale does not change over time. When
      dynamic loss scaling is used, it is better for this to be a very high
      number, because a loss scale that is too high gets lowered far more
      quickly than a loss scale that is too low gets raised.
343    dynamic_growth_steps: With dynamic loss scaling, every
344      `dynamic_growth_steps` steps with finite gradients, the loss scale is
345      doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
346      count is reset back to zero, gradients are skipped that step, and the loss
347      scale is halved. The count can be queried with
348      `LossScaleOptimizer.dynamic_counter`. This argument can only be specified
349      if `dynamic` is True.
350
351  `LossScaleOptimizer` will occasionally skip applying gradients to the
352  variables, in which case the trainable variables will not change that step.
353  This is done because the dynamic loss scale will sometimes be raised too
354  high, causing overflow in the gradients. Typically, the first 2 to 15 steps of
355  the model are skipped as the initial loss scale is very high, but afterwards
356  steps will only be skipped on average 0.05% of the time (the fraction of steps
357  skipped is `1 / dynamic_growth_steps`).
358
359  `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
360  optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales
361  the loss and unscales the gradients. In methods `minimize` and
362  `apply_gradients`, it additionally updates the loss scale and skips applying
363  gradients if any gradient has a nonfinite value.
364
365  ### Hyperparameters
366
367  Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
368  be delegated to the wrapped optimizer.
369
370  >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
371  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
372  >>> opt.beta_1  # Equivalent to `opt.inner_optimizer.beta_1`
373  0.8
374  >>> opt.beta_1 = 0.7  # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
375  >>> opt.beta_1
376  0.7
377  >>> opt.inner_optimizer.beta_1
378  0.7
379
380  However, accessing or setting non-hyperparameters is not delegated to the
381  LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but
382  `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on
383  `beta_1`.
384
385  >>> opt.inner_optimizer.epsilon
386  1e-5
387  >>> opt.epsilon
388  Traceback (most recent call last):
389  ...
390  AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
391  >>> opt.epsilon = 1e-4  # This does NOT set epsilon on `opt.inner_optimizer`
392  >>> opt.inner_optimizer.epsilon
  1e-5
394
395  In the above example, despite epsilon being set on the LossScaleOptimizer, the
396  old epsilon value will still be used when training as epsilon was not set on
397  the inner optimizer.
398  """
399
400  _HAS_AGGREGATE_GRAD = True
401
402  def __init__(self, inner_optimizer, dynamic=True, initial_scale=None,
403               dynamic_growth_steps=None):
404    if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
405      raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, '
406                      'but got: %s' % inner_optimizer)
407    if not isinstance(dynamic, bool):
408      # Catch errors if a user incorrectly passes a string or float to the
409      # second argument argument, as this is commonly done for
410      # LossScaleOptimizerV1.
411      raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must '
412                      'be a bool, but got: %r' % (dynamic,))
413    if isinstance(inner_optimizer, LossScaleOptimizer):
414      raise TypeError('LossScaleOptimizer cannot wrap another '
415                      'LossScaleOptimizer, but got: %s' % (inner_optimizer,))
416    self._raise_if_strategy_unsupported()
417    if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False):
418      # TODO(reedwm): Maybe support this. The difficulty is that LSO has the
419      # same checkpoint format as the inner optimizer, so multiple LSOs wrapping
420      # the same optimizer causes the checkpointing logic to become confused.
421      raise ValueError('"inner_optimizer" is already wrapped by a '
422                       'LossScaleOptimizer. An optimizer can only be wrapped '
423                       'by a single LossScaleOptimizer')
424    self._optimizer = inner_optimizer
425    self._optimizer._is_wrapped_by_loss_scale_optimizer = True
426
427    # We don't call super().__init__, since we do not want to call OptimizerV2's
428    # constructor.
429    base_delegate.DelegatingTrackableMixin.__init__(self, self._optimizer)
430
431    if dynamic:
432      if initial_scale is None:
433        initial_scale = _DEFAULT_INITIAL_SCALE
434      if dynamic_growth_steps is None:
435        dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
436      self._loss_scale = _DynamicLossScaleState(
437          initial_scale, dynamic_growth_steps, multiplier=2)
438      self._track_trackable(self._loss_scale, 'loss_scale')
439    else:
440      if initial_scale is None:
441        raise ValueError('"initial_scale" must be specified if "dynamic" is '
442                         'False')
443      self._loss_scale = float(initial_scale)
444      if dynamic_growth_steps is not None:
445        raise ValueError('"dynamic_growth_steps" must be None if "dynamic" '
446                         'is False, but got: %s' % (dynamic_growth_steps,))
447
448    # To support restoring TensorFlow 2.2 checkpoints.
449    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
450                          'base_optimizer')
451
452  @property
453  def dynamic(self):
454    """Bool indicating whether dynamic loss scaling is used."""
455    return isinstance(self._loss_scale, _DynamicLossScaleState)
456
457  @property
458  def loss_scale(self):
459    """The current loss scale as a float32 scalar tensor."""
460    if isinstance(self._loss_scale, _DynamicLossScaleState):
461      return ops.convert_to_tensor_v2_with_dispatch(
462          self._loss_scale.current_loss_scale)
463    else:
464      return ops.convert_to_tensor_v2_with_dispatch(self._loss_scale)
465
466  @property
467  def dynamic_counter(self):
468    """The number of steps since the loss scale was last increased or decreased.
469
470    This is None if `LossScaleOptimizer.dynamic` is False.
471
472    The counter is incremented every step. Once it reaches
473    `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled
474    and the counter will be reset back to zero. If nonfinite gradients are
475    encountered, the loss scale will be halved and the counter will be reset
476    back to zero.
477    """
478    if isinstance(self._loss_scale, _DynamicLossScaleState):
479      return self._loss_scale.counter
480    else:
481      return None
482
483  @property
484  def initial_scale(self):
485    """The initial loss scale.
486
487    If `LossScaleOptimizer.dynamic` is False, this is the same number as
488    `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
489    """
490    if isinstance(self._loss_scale, _DynamicLossScaleState):
491      return self._loss_scale.initial_loss_scale
492    else:
493      return self._loss_scale
494
495  @property
496  def dynamic_growth_steps(self):
497    """The number of steps it takes to increase the loss scale.
498
499    This is None if `LossScaleOptimizer.dynamic` is False.
500
501    Every `dynamic_growth_steps` consecutive steps with finite gradients, the
502    loss scale is increased.
503    """
504    if isinstance(self._loss_scale, _DynamicLossScaleState):
505      return self._loss_scale.growth_steps
506    else:
507      return None
508
509  @property
510  def inner_optimizer(self):
511    """The optimizer that this LossScaleOptimizer is wrapping."""
512    return self._optimizer
513
514  def get_scaled_loss(self, loss):
515    """Scales the loss by the loss scale.
516
517    This method is only needed if you compute gradients manually, e.g. with
518    `tf.GradientTape`. In that case, call this method to scale the loss before
519    passing the loss to `tf.GradientTape`. If you use
520    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
521    scaling is automatically applied and this method is unneeded.
522
523    If this method is called, `get_unscaled_gradients` should also be called.
524    See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
525    an example.
526
527    Args:
528      loss: The loss, which will be multiplied by the loss scale. Can either be
529        a tensor or a callable returning a tensor.
530
531    Returns:
532      `loss` multiplied by `LossScaleOptimizer.loss_scale`.
533    """
534    if callable(loss):
535      def new_loss():
536        loss_val = loss()
537        return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype)
538      return new_loss
539    else:
540      return loss * math_ops.cast(self.loss_scale, loss.dtype)
541
542  def get_unscaled_gradients(self, grads):
543    """Unscales the gradients by the loss scale.
544
545    This method is only needed if you compute gradients manually, e.g. with
546    `tf.GradientTape`. In that case, call this method to unscale the gradients
547    after computing them with `tf.GradientTape`. If you use
548    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
549    scaling is automatically applied and this method is unneeded.
550
551    If this method is called, `get_scaled_loss` should also be called. See
552    the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
553    example.
554
555    Args:
556      grads: A list of tensors, each which will be divided by the loss scale.
557        Can have None values, which are ignored.
558
559    Returns:
560      A new list the same size as `grads`, where every non-None value in `grads`
561      is divided by `LossScaleOptimizer.loss_scale`.
562    """
563    loss_scale_reciprocal = 1. / self.loss_scale
564    return [
565        _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
566        for g in grads
567    ]
568
569  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
570    tape = backprop.GradientTape() if tape is None else tape
571    with tape:
572      loss = self.get_scaled_loss(loss)
573    grads_and_vars = self._optimizer._compute_gradients(  # pylint: disable=protected-access
574        loss,
575        var_list,
576        grad_loss,
577        tape=tape)
578    grads = [g for g, _ in grads_and_vars]
579    weights = [v for _, v in grads_and_vars]
580    unscaled_grads = self.get_unscaled_gradients(grads)
581    return list(zip(unscaled_grads, weights))
582
583  def get_gradients(self, loss, params):
584    loss = self.get_scaled_loss(loss)
585    grads = self._optimizer.get_gradients(loss, params)
586    return self.get_unscaled_gradients(grads)
587
588  def _create_all_weights(self, var_list):
589    self._optimizer._create_all_weights(var_list)    # pylint: disable=protected-access
590
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Applies gradients, skipping the step if any gradient is nonfinite.

    Must be called in a replica context. With dynamic loss scaling, the loss
    scale is also updated according to whether all gradients are finite. When
    the step is skipped, the inner optimizer's `iterations` is still
    incremented.

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs. Gradients are
        expected to be unscaled already (e.g. via `get_unscaled_gradients`).
      name: Optional name forwarded to the inner optimizer's
        `apply_gradients`.
      experimental_aggregate_gradients: If True, gradients are transformed and
        aggregated across replicas here, before the finiteness check.

    Returns:
      A `control_flow_ops.group` of the (possibly skipped) apply op and the
      loss scale update op.

    Raises:
      ValueError: If called in a cross-replica context, or if the current
        distribution strategy does not support loss scaling.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()

    # Presumably drops pairs whose gradient is None (see optimizer_utils).
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    if experimental_aggregate_gradients:
      # We must aggregate the gradients here instead of in
      # self.optimizer.apply_gradients, so that any NaN or Inf gradients are
      # propagated to each replica. If any replica has a NaN or Inf gradient,
      # they must all have a NaN or Inf gradient so that they all skip the step.
      # pylint: disable=protected-access
      grads_and_vars = self._optimizer._transform_unaggregated_gradients(
          grads_and_vars)
      grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars)
      # pylint: enable=protected-access

    # Materialize once: grads_and_vars is iterated twice below.
    grads_and_vars = tuple(grads_and_vars)
    grads = [g for g, _ in grads_and_vars]
    # We do not want DistributionStrategy to unwrap any MirroredVariables in
    # grads_and_vars, because even in a replica context, the wrapped
    # optimizer expects mirrored variables. So we wrap the variables with an
    # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
    # MirroredVariables.
    wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

    def do_not_apply_fn():
      # Normally self._optimizer.iterations is incremented in
      # self._optimizer.apply_gradients(). Since that is not called in this
      # branch, we increment it here instead.
      return self._optimizer.iterations.assign_add(1, read_value=False)

    def _if_should_apply_grads(grads):
      # Returns (loss_scale_update_op, should_apply). With a fixed loss scale
      # there is nothing to update and gradients are always applied.
      if isinstance(self._loss_scale, _DynamicLossScaleState):
        return self._loss_scale.update(grads)
      else:
        return (control_flow_ops.no_op(), True)

    if optimizer_utils.strategy_supports_no_merge_call():
      # Fast path: no merge_call needed, so the cond can be built directly in
      # the replica context.
      loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)
      def apply_fn():
        return self._apply_gradients(grads, wrapped_vars, name)

      # smart_cond picks a branch statically when `should_apply_grads` is a
      # Python bool (fixed loss scale) and builds a cond otherwise.
      maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                             do_not_apply_fn)
      return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

    else:

      def _apply_gradients_cross_replica(distribution, grads, wrapped_vars,
                                         name):
        # Runs in the cross-replica context via merge_call below.
        loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)

        def apply_fn():
          return distribution.extended.call_for_each_replica(
              self._apply_gradients,
              args=(grads, wrapped_vars, name))

        # Note: We must call this cond() in a cross-replica context.
        # DistributionStrategy does not support having a cond in a replica
        # context with a branch that calls `merge_call`, and
        # self._optimizer.apply_gradients calls `merge_call`.
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                               do_not_apply_fn)
        return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
      return distribution_strategy_context.get_replica_context().merge_call(
          _apply_gradients_cross_replica,
          args=(grads, wrapped_vars, name))
664
665  def _apply_gradients(self, grads, wrapped_vars, name):
666    # Pass experimental_aggregate_gradients=False since LossScaleOptimizer
667    # already aggregated the gradients.
668    # TODO(reedwm): This will raise a fairly cryptic error message if
669    # self._optimizer.apply_gradients does not take
670    # experimental_aggregate_gradients.
671    return self._optimizer.apply_gradients(
672        list(zip(grads, wrapped_vars.value)), name,
673        experimental_aggregate_gradients=False)
674
675  def get_config(self):
676    serialized_optimizer = optimizers.serialize(self._optimizer)
677    return {
678        'inner_optimizer': serialized_optimizer,
679        'dynamic': self.dynamic,
680        'initial_scale': self.initial_scale,
681        'dynamic_growth_steps': self.dynamic_growth_steps,
682    }
683
684  @classmethod
685  def from_config(cls, config, custom_objects=None):
686    config = config.copy()  # Make a copy, since we mutate config
687    if 'loss_scale' in config:
688      # If loss_scale is in config, we assume we are deserializing a
689      # LossScaleOptimizer from TF 2.3 or below. We convert the config so it
690      # can be deserialized in the current LossScaleOptimizer.
691      loss_scale = keras_loss_scale_module.deserialize(
692          config.pop('loss_scale'))
693      if isinstance(loss_scale, loss_scale_module.FixedLossScale):
694        config['dynamic'] = False
695        config['initial_scale'] = loss_scale._loss_scale_value  # pylint: disable=protected-access
696      elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
697        config['dynamic'] = True
698        config['initial_scale'] = loss_scale.initial_loss_scale
699        config['dynamic_growth_steps'] = loss_scale.increment_period
700        if loss_scale.multiplier != 2:
701          raise ValueError('Cannot deserialize LossScaleOptimizer with a '
702                           'DynamicLossScale whose multiplier is not 2. Got '
703                           'DynamicLossScale: %s' % (loss_scale,))
704      else:
705        raise ValueError(
706            'Serialized LossScaleOptimizers with a LossScale that is neither a '
707            'FixedLossScale nor a DynamicLossScale can no longer be '
708            'deserialized')
709      config['inner_optimizer'] = config.pop('optimizer')
710    config['inner_optimizer'] = optimizers.deserialize(
711        config['inner_optimizer'], custom_objects=custom_objects)
712    return cls(**config)
713
714  def _raise_if_strategy_unsupported(self):
715    if not strategy_supports_loss_scaling():
716      strategy = distribution_strategy_context.get_strategy()
717      if isinstance(strategy,
718                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
719                     tpu_strategy.TPUStrategyV2)):
720        raise ValueError(
721            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
722            'unnecessary with TPUs, since they support bfloat16 instead of '
723            'float16 and bfloat16 does not require loss scaling. You should '
724            'remove the use of the LossScaleOptimizer when TPUs are used.')
725      else:
726        raise ValueError('Loss scaling is not supported with the '
727                         'tf.distribute.Strategy: %s. Try using a different '
728                         'Strategy, e.g. a MirroredStrategy' %
729                         strategy.__class__.__name__)
730
731  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
732  # below.
733
734  @property
735  def iterations(self):
736    return self._optimizer.iterations
737
738  @iterations.setter
739  def iterations(self, variable):
740    self._optimizer.iterations = variable
741
742  def get_slot_names(self):
743    return self._optimizer.get_slot_names()
744
745  def variables(self):
746    return self._optimizer.variables()
747
748  @property
749  def weights(self):
750    return self._optimizer.weights
751
752  def get_weights(self):
753    return self._optimizer.get_weights()
754
755  def set_weights(self, weights):
756    return self._optimizer.set_weights(weights)
757
758  @property
759  def clipnorm(self):
760    return self._optimizer.clipnorm
761
762  @clipnorm.setter
763  def clipnorm(self, val):
764    self._optimizer.clipnorm = val
765
766  @property
767  def global_clipnorm(self):
768    return self._optimizer.global_clipnorm
769
770  @global_clipnorm.setter
771  def global_clipnorm(self, val):
772    self._optimizer.global_clipnorm = val
773
774  @property
775  def clipvalue(self):
776    return self._optimizer.clipvalue
777
778  @clipvalue.setter
779  def clipvalue(self, val):
780    self._optimizer.clipvalue = val
781
782  def _aggregate_gradients(self, grads_and_vars):
783    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
784
785  def _restore_slot_variable(self, slot_name, variable, slot_variable):
786    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
787                                                  slot_variable)
788
789  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
790                                       variable):
791    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
792        slot_variable_position, slot_name, variable)
793
794  def get_slot(self, var, slot_name):
795    return self._optimizer.get_slot(var, slot_name)
796
797  def add_slot(self, var, slot_name, initializer='zeros'):
798    return self._optimizer.add_slot(var, slot_name, initializer)
799
  def __getattribute__(self, name):
    """Falls back to the inner optimizer's hyperparameters on lookup failure.

    Normal attribute lookup is tried first. Only if it raises AttributeError is
    the name ('lr' mapped to 'learning_rate') checked against the inner
    optimizer's `_hyper` dict. If present, the value returned by the inner
    optimizer's `_get_hyper` is used. Otherwise the original AttributeError is
    re-raised.
    """
    try:
      return object.__getattribute__(self, name)
    except AttributeError as e:
      if name == '_optimizer' or name == '_hyper':
        # Avoid infinite recursion
        # (the fallback below reads `self._optimizer` and then `._hyper`; if a
        # lookup of either of those names were itself delegated, it could
        # re-enter this fallback, so they fail immediately instead).
        raise e

      # Delegate hyperparameter accesses to inner optimizer.
      if name == 'lr':
        # 'lr' is accepted as an alias for the 'learning_rate' hyperparameter.
        name = 'learning_rate'
      if name in self._optimizer._hyper:
        return self._optimizer._get_hyper(name)
      raise e
814
815  def __dir__(self):
816    result = set(super(LossScaleOptimizer, self).__dir__())
817    if '_optimizer' in result:
818      result |= self._optimizer._hyper.keys()
819      if 'learning_rate' in self._optimizer._hyper.keys():
820        result.add('lr')
821    return list(result)
822
  def __setattr__(self, name, value):
    """Routes hyperparameter writes to the inner optimizer when appropriate.

    'lr' is treated as 'learning_rate'. If normal lookup cannot find the
    (possibly renamed) attribute on this object and the name is a key of the
    inner optimizer's `_hyper` dict, the value is stored with the inner
    optimizer's `_set_hyper`. Otherwise the attribute is set through the
    parent class's `__setattr__`.
    """
    if name == 'lr':
      name = 'learning_rate'
    # Delegate setting hyperparameter to inner optimizer if the attribute does
    # not exist on the LossScaleOptimizer
    try:
      # We cannot check for the 'iterations' attribute as it cannot be set after
      # it is accessed.
      # As a result, 'iterations' is always treated as present and is never
      # routed to `_set_hyper`.
      if name != 'iterations':
        # Note: for names backed by a property on this class (e.g.
        # 'learning_rate'), this lookup runs the property getter.
        object.__getattribute__(self, name)
      has_attribute = True
    except AttributeError:
      has_attribute = False
    # `name != '_optimizer'` is tested first so that assigning `_optimizer`
    # itself never reads `self._optimizer`, which may not exist yet.
    if (name != '_optimizer' and name in self._optimizer._hyper
        and not has_attribute):
      self._optimizer._set_hyper(name, value)
    else:
      super(LossScaleOptimizer, self).__setattr__(name, value)
841
842  # Explicitly delegate learning_rate. Normally hyperparameters are delegated in
843  # __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
844  # (e.g. because self._optimizer itself wraps another optimizer), then it won't
845  # be delegated. Since learning_rate is a very commonly accessed
846  # hyperparameter, we delegate it here.
847  @property
848  def learning_rate(self):
849    return self._optimizer.learning_rate
850
851  @learning_rate.setter
852  def learning_rate(self, value):
853    self._optimizer.learning_rate = value
854
855  @property
856  def lr(self):
857    return self._optimizer.learning_rate
858
859  @lr.setter
860  def lr(self, value):
861    self._optimizer.lr = value
862
863  # We do not override some OptimizerV2 methods. For each, we describe why we do
864  # not delegate them to self._optimizer:
865  # * get_updates: get_updates() calls get_gradients(). Since we override
866  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
867  #   otherwise the overridden get_gradients() method would not be called.
868  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
869  #   inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for a similar reason as get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and both
  #   need to have the LossScaleOptimizer version called.
873
874  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
875  # optimizer being used.
876
877
878@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
879class LossScaleOptimizerV1(LossScaleOptimizer):
880  """An deprecated optimizer that applies loss scaling.
881
882  Warning: This class is deprecated and will be removed in a future version of
883  TensorFlow. Please use the non-experimental class
884  `tf.keras.mixed_precision.LossScaleOptimizer` instead.
885
886  This class is identical to the non-experimental
887  `keras.mixed_precision.LossScaleOptimizer` except its constructor takes
888  different arguments. For this class (the experimental version), the
889  constructor takes a `loss_scale` argument.  For the non-experimental class,
890  the constructor encodes the loss scaling information in multiple arguments.
891  Note that unlike this class, the non-experimental class does not accept a
892  `tf.compat.v1.mixed_precision.LossScale`, which is deprecated.
893
894  If you currently use this class, you should switch to the non-experimental
895  `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
896  examples of converting the use of the experimental class to the equivalent
897  non-experimental class.
898
899  >>> # In all of the examples below, `opt1` and `opt2` are identical
900  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
901  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
902  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
903  ...     tf.keras.optimizers.SGD())
904  >>> assert opt1.get_config() == opt2.get_config()
905
906  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
907  ...     tf.keras.optimizers.SGD(), loss_scale=123)
908  >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
909  >>> # refers to the initial loss scale, which is the single fixed loss scale
910  >>> # when dynamic=False.
911  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
912  ...     tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
913  >>> assert opt1.get_config() == opt2.get_config()
914
915  >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
916  ...     initial_loss_scale=2048, increment_period=500)
917  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
918  ...     tf.keras.optimizers.SGD(), loss_scale=loss_scale)
919  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
920  ...     tf.keras.optimizers.SGD(), initial_scale=2048,
921  ...     dynamic_growth_steps=500)
922  >>> assert opt1.get_config() == opt2.get_config()
923
924  Make sure to also switch from this class to the non-experimental class in
925  isinstance checks, if you have any. If you do not do this, your model may run
926  into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
927  the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
928  switch isinstance checks to the non-experimental `LossScaleOptimizer` even
929  before using the non-experimental `LossScaleOptimizer`.
930
931  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
932  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
933  >>> # The experimental class subclasses the non-experimental class
934  >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
935  True
936  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
937  ...     tf.keras.optimizers.SGD())
938  >>> # The non-experimental class does NOT subclass the experimental class.
939  >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
940  False
941
942  Args:
943    optimizer: The Optimizer instance to wrap.
944    loss_scale: The loss scale to scale the loss and gradients. This can
945      either be an int/float to use a fixed loss scale, the string "dynamic"
946      to use dynamic loss scaling, or an instance of a LossScale. The string
947      "dynamic" equivalent to passing `DynamicLossScale()`, and passing an
948      int/float is equivalent to passing a FixedLossScale with the given loss
949      scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must
950      be 2 (the default).
951  """
952
953  def __init__(self, optimizer, loss_scale):
954    warn_msg_prefix = (
955        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
956        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
957        'instead. ')
958
959    if isinstance(loss_scale, dict):
960      loss_scale = keras_loss_scale_module.deserialize(loss_scale)
961
962    if isinstance(loss_scale, (int, float)):
963      tf_logging.warning(
964          warn_msg_prefix + 'For example:\n'
965          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
966          'opt, dynamic=False, initial_scale={})'.format(loss_scale))
967      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
968                                                 initial_scale=loss_scale)
969    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
970      ls_val = loss_scale._loss_scale_value  # pylint: disable=protected-access
971      tf_logging.warning(
972          warn_msg_prefix + 'For example:\n'
973          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
974          'opt, dynamic=False, initial_scale={})'.format(ls_val))
975      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
976                                                 initial_scale=ls_val)
977    elif loss_scale == 'dynamic':
978      tf_logging.warning(
979          warn_msg_prefix + 'For example:\n'
980          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
981          'opt)')
982      super(LossScaleOptimizerV1, self).__init__(optimizer)
983    elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
984      kwargs = {}
985      extra_arguments = ''
986      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
987        kwargs['initial_scale'] = loss_scale.initial_loss_scale
988        extra_arguments += (', initial_scale=%s' %
989                            loss_scale.initial_loss_scale)
990      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
991        kwargs['dynamic_growth_steps'] = loss_scale.increment_period
992        extra_arguments += (', dynamic_growth_steps=%s' %
993                            loss_scale.increment_period)
994      if loss_scale.multiplier != 2:
995        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
996                         'DynamicLossScale.multiplier must be 2. Got: %s'
997                         % (loss_scale,))
998      tf_logging.warning(
999          warn_msg_prefix +
1000          'Note that the non-experimental LossScaleOptimizer does not take a '
1001          'DynamicLossScale but instead takes the dynamic configuration '
1002          'directly in the constructor. For example:\n'
1003          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
1004          'opt{})\n'.format(extra_arguments))
1005      super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs)
1006    elif isinstance(loss_scale, loss_scale_module.LossScale):
1007      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
1008                      'DynamicLossScale is no longer supported. Got: {}'
1009                      .format(loss_scale))
1010    else:
1011      raise ValueError('Invalid value passed to loss_scale. loss_scale '
1012                       'must be the string "dynamic" (recommended), an int, '
1013                       'a float, a FixedLossScale, or a DynamicLossScale. Got '
1014                       'value: {}'.format(loss_scale))
1015
1016  @classmethod
1017  def from_config(cls, config, custom_objects=None):
1018    config = config.copy()  # Make a copy, since we mutate config
1019
1020    # If loss_scale is in config, we assume we are deserializing a
1021    # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are
1022    # deserializing a LossScaleOptimizer from TF 2.4 or above.
1023    if 'loss_scale' in config:
1024      config['loss_scale'] = keras_loss_scale_module.deserialize(
1025          config['loss_scale'])
1026      if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale)
1027          and config['loss_scale'].multiplier != 2):
1028        raise ValueError('Cannot deserialize LossScaleOptimizer with a '
1029                         'DynamicLossScale whose multiplier is not 2. Got '
1030                         'DynamicLossScale: %s' % (config['loss_scale'],))
1031      config['optimizer'] = optimizers.deserialize(
1032          config['optimizer'], custom_objects=custom_objects)
1033      return cls(**config)
1034
1035    # We convert the config, as generated by LossScaleOptimizer.get_config, to a
1036    # version that can be passed to LossScaleOptimizerV1.__init__
1037    if config['dynamic']:
1038      config['loss_scale'] = loss_scale_module.DynamicLossScale(
1039          config['initial_scale'], config['dynamic_growth_steps'], multiplier=2)
1040    else:
1041      config['loss_scale'] = loss_scale_module.FixedLossScale(
1042          config['initial_scale'])
1043
1044    del config['dynamic']
1045    del config['initial_scale']
1046    del config['dynamic_growth_steps']
1047    config['optimizer'] = optimizers.deserialize(
1048        config.pop('inner_optimizer'), custom_objects=custom_objects)
1049    return cls(**config)
1050
1051
1052class FakeOptimizerForRestoration(trackable.Trackable):
1053  """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.
1054
1055  The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
1056  exists to support restoring TF 2.2 checkpoints in newer version of TensorFlow.
1057
1058  In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
1059  following in LossScaleOptimizer.__init__
1060
1061  ```
1062  self._track_trackable(self._optimizer, 'base_optimizer')
1063  ```
1064
1065  This means a dependency from the LossScaleOptimizer to the wrapped optimizer
1066  would be stored in the checkpoint. However now, the checkpoint format with a
1067  LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
1068  except the loss scale is also stored. This means there is no dependency from
1069  the LossScaleOptimizer to the wrapped optimizer. Instead, the
1070  LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
1071  perspective, by overriding all Trackable methods and delegating them to the
1072  wrapped optimizer.
1073
1074  To allow restoring TF 2.2. checkpoints, LossScaleOptimizer adds a dependency
1075  on this class instead of the inner optimizer. When restored, this class will
1076  instead restore the slot variables of the inner optimizer. Since this class
1077  has no variables, it does not affect the checkpoint when saved.
1078  """
1079
1080  def __init__(self, optimizer):
1081    self._optimizer = optimizer
1082
1083  def get_slot_names(self):
1084    return self._optimizer.get_slot_names()
1085
1086  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
1087                                       variable):
1088    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
1089        slot_variable_position, slot_name, variable)
1090
1091
1092mixed_precision.register_loss_scale_wrapper(optimizer_v2.OptimizerV2,
1093                                            LossScaleOptimizerV1)
1094
1095
1096def _multiply_gradient(gradient, scale):
1097  """Multiply a (possibly sparse) gradient by the given scale factor."""
1098  scale = math_ops.cast(scale, gradient.dtype)
1099  if isinstance(gradient, indexed_slices.IndexedSlices):
1100    return indexed_slices.IndexedSlices(
1101        gradient.values * scale,
1102        gradient.indices,
1103        dense_shape=gradient.dense_shape)
1104  else:
1105    return gradient * scale
1106
1107
1108def strategy_supports_loss_scaling():
1109  """Returns True if the current Strategy supports loss scaling."""
1110  if not distribution_strategy_context.has_strategy():
1111    return True
1112  strategy = distribution_strategy_context.get_strategy()
1113  # Strategies are supported if either there is only one replica or if variables
1114  # are replicated per device. Otherwise, the current model.fit() implementation
1115  # and most custom training loops incorrectly unscale the gradients. Currently,
1116  # gradients are unscaled once per compute replica, but they should be unscaled
1117  # once per variable replica. When there is one variable replica for each
1118  # compute replica, this works fine, but otherwise issues will occur.
1119  # TODO(reedwm): Support all strategies.
1120  return isinstance(strategy, (
1121      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
1122      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
1123      one_device_strategy.OneDeviceStrategy,
1124      one_device_strategy.OneDeviceStrategyV1,
1125      mirrored_strategy.MirroredStrategy,
1126      mirrored_strategy.MirroredStrategyV1,
1127  ))
1128