1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the loss scaling optimizer class.""" 16 17from tensorflow.python.distribute import collective_all_reduce_strategy 18from tensorflow.python.distribute import distribution_strategy_context 19from tensorflow.python.distribute import mirrored_strategy 20from tensorflow.python.distribute import one_device_strategy 21from tensorflow.python.distribute import tpu_strategy 22from tensorflow.python.eager import backprop 23from tensorflow.python.eager import context 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import indexed_slices 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import smart_cond 28from tensorflow.python.keras import backend 29from tensorflow.python.keras import optimizers 30from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 31from tensorflow.python.keras.optimizer_v2 import optimizer_v2 32from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import variable_scope 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import tf_logging 38from tensorflow.python.trackable import 
base as trackable 39from tensorflow.python.trackable import base_delegate 40from tensorflow.python.training.experimental import loss_scale as loss_scale_module 41from tensorflow.python.training.experimental import mixed_precision 42from tensorflow.python.util import nest 43from tensorflow.python.util.tf_export import keras_export 44 45 46class _UnwrapPreventer(object): 47 """Wrapper that DistributionStrategy will not unwrap. 48 49 Typically, DistributionStrategy will unwrap values when going from a cross- 50 replica context to a replica context via `call_for_each_replica`. This class 51 is a wrapper that DistributionStrategy will not unwrap, so it can be used to 52 prevent it from unwrapping a value. 53 54 TODO(reedwm): Find/implement a better way of preventing values from being 55 unwrapped by DistributionStrategy 56 """ 57 58 __slots__ = ['value'] 59 60 def __init__(self, value): 61 self.value = value 62 63 64def _is_all_finite(grads): 65 """Returns a scalar boolean tensor indicating if all gradients are finite.""" 66 is_finite_per_grad = [ 67 math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None 68 ] 69 return math_ops.reduce_all(is_finite_per_grad) 70 71 72def _op_in_graph_mode(tensor): 73 """Returns the tensor's op in graph mode, or the tensor in eager mode. 74 75 This is useful because sometimes an op is needed in graph mode instead of a 76 tensor. In eager mode, there are no ops. 77 78 Args: 79 tensor: A tensor. 80 81 Returns: 82 The tensor's op in graph mode. The tensor in eager mode. 
83 """ 84 if context.executing_eagerly(): 85 return tensor 86 return tensor.op 87 88 89def _assign_if_finite(var, value): 90 """Assigns a value to a variable if the value is finite.""" 91 return control_flow_ops.cond( 92 math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), 93 control_flow_ops.no_op) 94 95 96class _DynamicLossScaleState(trackable.Trackable): 97 """The state of a dynamic loss scale.""" 98 99 def __init__(self, 100 initial_loss_scale, 101 growth_steps, 102 multiplier): 103 """Creates the dynamic loss scale.""" 104 super(_DynamicLossScaleState, self).__init__() 105 self._initial_loss_scale = float(initial_loss_scale) 106 self._growth_steps = int(growth_steps) 107 self._multiplier = float(multiplier) 108 109 self._weights = {} 110 self._current_loss_scale = self._add_weight( 111 name='current_loss_scale', 112 dtype=dtypes.float32, 113 initial_value=self._initial_loss_scale) 114 # The number of consecutive steps with finite gradients since the last 115 # nonfinite gradient or change in loss scale. The name is 'good_steps' for 116 # backwards compatibility with older checkpoints. 117 self._counter = self._add_weight( 118 name='good_steps', dtype=dtypes.int64, initial_value=0) 119 120 def _add_weight(self, name, initial_value, dtype=None): 121 """Adds a weight to this loss scale. 122 123 Args: 124 name: Variable name. 125 initial_value: The variable's initial value. 126 dtype: The type of the variable. 127 128 Returns: 129 A variable. 130 131 Raises: 132 RuntimeError: If a weight with `name` has already been added. 133 """ 134 variable = variable_scope.variable( 135 initial_value=initial_value, 136 name=name, 137 dtype=dtype, 138 trainable=False, 139 use_resource=True, 140 synchronization=variables.VariableSynchronization.AUTO, 141 # Set aggregation to NONE, as loss scaling variables should never be 142 # aggregated. 
143 aggregation=variables.VariableAggregation.NONE) 144 if context.executing_eagerly(): 145 graph_key = None 146 else: 147 graph = ops.get_default_graph() 148 graph_key = graph._graph_key # pylint: disable=protected-access 149 150 key = (name, graph_key) 151 self._weights[key] = variable 152 self._handle_deferred_dependencies(name=name, trackable=variable) 153 backend.track_variable(variable) 154 return variable 155 156 def _trackable_children(self, 157 save_type=trackable.SaveType.CHECKPOINT, 158 **kwargs): 159 """From Trackable. Gather graph-specific weights to save.""" 160 if context.executing_eagerly(): 161 graph_key = None 162 else: 163 graph = ops.get_default_graph() 164 graph_key = graph._graph_key # pylint: disable=protected-access 165 weights = {} 166 for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): 167 if g == graph_key: 168 weights[name] = v 169 weights.update( 170 super(_DynamicLossScaleState, 171 self)._trackable_children(save_type, **kwargs)) 172 return weights 173 174 def _lookup_dependency(self, name): 175 """From Trackable. 
Find a weight in the current graph.""" 176 unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name) 177 if unconditional is not None: 178 return unconditional 179 if context.executing_eagerly(): 180 graph_key = None 181 else: 182 graph = ops.get_default_graph() 183 graph_key = graph._graph_key # pylint: disable=protected-access 184 return self._weights.get((name, graph_key), None) 185 186 @property 187 def initial_loss_scale(self): 188 return self._initial_loss_scale 189 190 @property 191 def growth_steps(self): 192 return self._growth_steps 193 194 @property 195 def multiplier(self): 196 return self._multiplier 197 198 @property 199 def current_loss_scale(self): 200 """Returns the current loss scale as a float32 `tf.Variable`.""" 201 return self._current_loss_scale 202 203 @property 204 def counter(self): 205 """Returns the counter as a float32 `tf.Variable`.""" 206 return self._counter 207 208 def __call__(self): 209 """Returns the current loss scale as a scalar `float32` tensor.""" 210 return ops.convert_to_tensor_v2_with_dispatch(self._current_loss_scale) 211 212 def update(self, grads): 213 """Updates the value of the loss scale. 214 215 Args: 216 grads: A nested structure of unscaled gradients, each which is an 217 all-reduced gradient of the loss with respect to a weight. 218 219 Returns: 220 update_op: In eager mode, None. In graph mode, an op to update the loss 221 scale. 222 should_apply_gradients: Either a bool or a scalar boolean tensor. If 223 False, the caller should skip applying `grads` to the variables this 224 step. 
225 """ 226 grads = nest.flatten(grads) 227 if distribution_strategy_context.has_strategy( 228 ) and distribution_strategy_context.in_cross_replica_context(): 229 distribution = distribution_strategy_context.get_strategy() 230 is_finite_per_replica = distribution.extended.call_for_each_replica( 231 _is_all_finite, args=(grads,)) 232 # Each replica computed the same `is_finite` value, since `grads` is 233 # all-reduced across replicas. Arbitrarily take `is_finite` from the first 234 # replica. 235 is_finite = ( 236 distribution.experimental_local_results(is_finite_per_replica)[0]) 237 else: 238 is_finite = _is_all_finite(grads) 239 240 def update_if_finite_grads(): 241 """Update assuming the gradients are finite.""" 242 243 def incr_loss_scale(): 244 new_loss_scale = self.current_loss_scale * self.multiplier 245 return control_flow_ops.group( 246 _assign_if_finite(self.current_loss_scale, new_loss_scale), 247 self.counter.assign(0)) 248 249 return control_flow_ops.cond( 250 self.counter + 1 >= self.growth_steps, 251 incr_loss_scale, 252 lambda: _op_in_graph_mode(self.counter.assign_add(1))) 253 254 def update_if_not_finite_grads(): 255 """Update assuming the gradients are nonfinite.""" 256 257 new_loss_scale = math_ops.maximum( 258 self.current_loss_scale / self.multiplier, 1) 259 return control_flow_ops.group( 260 self.counter.assign(0), 261 self.current_loss_scale.assign(new_loss_scale)) 262 263 update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, 264 update_if_not_finite_grads) 265 should_apply_gradients = is_finite 266 return update_op, should_apply_gradients 267 268 269# See LossScaleOptimizer docstring for why this is so big 270_DEFAULT_INITIAL_SCALE = 2 ** 15 271_DEFAULT_GROWTH_STEPS = 2000 272 273 274# pylint: disable=g-classes-have-attributes 275@keras_export('keras.mixed_precision.LossScaleOptimizer') 276class LossScaleOptimizer(base_delegate.DelegatingTrackableMixin, 277 optimizer_v2.OptimizerV2): 278 """An optimizer that applies loss 
scaling to prevent numeric underflow. 279 280 Loss scaling is a technique to prevent numeric underflow in intermediate 281 gradients when float16 is used. To prevent underflow, the loss is multiplied 282 (or "scaled") by a certain factor called the "loss scale", which causes 283 intermediate gradients to be scaled by the loss scale as well. The final 284 gradients are divided (or "unscaled") by the loss scale to bring them back to 285 their original value. 286 287 `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. 288 By default, the loss scale is dynamically updated over time so you do not have 289 to choose the loss scale. The `minimize` method automatically scales the loss, 290 unscales the gradients, and updates the loss scale so all you have to do is 291 wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For 292 example: 293 294 >>> opt = tf.keras.optimizers.SGD(0.25) 295 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 296 >>> var = tf.Variable(1.) 297 >>> loss_fn = lambda: var ** 2 298 >>> # 'minimize' applies loss scaling and updates the loss sale. 299 >>> opt.minimize(loss_fn, var_list=var) 300 >>> var.numpy() 301 0.5 302 303 If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you 304 must scale the loss and gradients manually. This can be done with the 305 `LossScaleOptimizer.get_scaled_loss` and 306 `LossScaleOptimizer.get_unscaled_gradients` methods. For example: 307 308 >>> with tf.GradientTape() as tape: 309 ... loss = loss_fn() 310 ... scaled_loss = opt.get_scaled_loss(loss) 311 >>> scaled_grad = tape.gradient(scaled_loss, var) 312 >>> (grad,) = opt.get_unscaled_gradients([scaled_grad]) 313 >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here 314 >>> var.numpy() 315 0.25 316 317 Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients` 318 (or both) when using a `tf.GradientTape`, the model will likely converge to a 319 worse quality. 
Please make sure you call each function exactly once. 320 321 When mixed precision with float16 is used, there is typically no risk of 322 underflow affecting model quality if loss scaling is properly used. See 323 [the mixed precision guide]( 324 https://www.tensorflow.org/guide/keras/mixed_precision) for more information 325 on how to use mixed precision. 326 327 Args: 328 inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap. 329 dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to 330 True. If True, the loss scale will be dynamically updated over time using 331 an algorithm that keeps the loss scale at approximately its optimal value. 332 If False, a single fixed loss scale is used and `initial_scale` must be 333 specified, which is used as the loss scale. Recommended to keep as True, 334 as choosing a fixed loss scale can be tricky. Currently, there is a small 335 performance overhead to dynamic loss scaling compared to fixed loss 336 scaling. 337 initial_scale: The initial loss scale. If `dynamic` is True, this defaults 338 to `2 ** 15`. If `dynamic` is False, this must be specified and acts as 339 the sole loss scale, as the loss scale does not change over time. When 340 dynamic loss scaling is used, is better for this to be a very high number, 341 because a loss scale that is too high gets lowered far more quickly than a 342 loss scale that is too low gets raised. 343 dynamic_growth_steps: With dynamic loss scaling, every 344 `dynamic_growth_steps` steps with finite gradients, the loss scale is 345 doubled. Defaults to 2000. If a nonfinite gradient is encountered, the 346 count is reset back to zero, gradients are skipped that step, and the loss 347 scale is halved. The count can be queried with 348 `LossScaleOptimizer.dynamic_counter`. This argument can only be specified 349 if `dynamic` is True. 
350 351 `LossScaleOptimizer` will occasionally skip applying gradients to the 352 variables, in which case the trainable variables will not change that step. 353 This is done because the dynamic loss scale will sometimes be raised too 354 high, causing overflow in the gradients. Typically, the first 2 to 15 steps of 355 the model are skipped as the initial loss scale is very high, but afterwards 356 steps will only be skipped on average 0.05% of the time (the fraction of steps 357 skipped is `1 / dynamic_growth_steps`). 358 359 `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner 360 optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales 361 the loss and unscales the gradients. In methods `minimize` and 362 `apply_gradients`, it additionally updates the loss scale and skips applying 363 gradients if any gradient has a nonfinite value. 364 365 ### Hyperparameters 366 367 Hyperparameters can be accessed and set on the LossScaleOptimizer, which will 368 be delegated to the wrapped optimizer. 369 370 >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) 371 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 372 >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1` 373 0.8 374 >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7` 375 >>> opt.beta_1 376 0.7 377 >>> opt.inner_optimizer.beta_1 378 0.7 379 380 However, accessing or setting non-hyperparameters is not delegated to the 381 LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but 382 `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on 383 `beta_1`. 384 385 >>> opt.inner_optimizer.epsilon 386 1e-5 387 >>> opt.epsilon 388 Traceback (most recent call last): 389 ... 
390 AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' 391 >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer` 392 >>> opt.inner_optimizer.epsilon 393 >>> 1e-5 394 395 In the above example, despite epsilon being set on the LossScaleOptimizer, the 396 old epsilon value will still be used when training as epsilon was not set on 397 the inner optimizer. 398 """ 399 400 _HAS_AGGREGATE_GRAD = True 401 402 def __init__(self, inner_optimizer, dynamic=True, initial_scale=None, 403 dynamic_growth_steps=None): 404 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): 405 raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, ' 406 'but got: %s' % inner_optimizer) 407 if not isinstance(dynamic, bool): 408 # Catch errors if a user incorrectly passes a string or float to the 409 # second argument argument, as this is commonly done for 410 # LossScaleOptimizerV1. 411 raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' 412 'be a bool, but got: %r' % (dynamic,)) 413 if isinstance(inner_optimizer, LossScaleOptimizer): 414 raise TypeError('LossScaleOptimizer cannot wrap another ' 415 'LossScaleOptimizer, but got: %s' % (inner_optimizer,)) 416 self._raise_if_strategy_unsupported() 417 if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False): 418 # TODO(reedwm): Maybe support this. The difficulty is that LSO has the 419 # same checkpoint format as the inner optimizer, so multiple LSOs wrapping 420 # the same optimizer causes the checkpointing logic to become confused. 421 raise ValueError('"inner_optimizer" is already wrapped by a ' 422 'LossScaleOptimizer. An optimizer can only be wrapped ' 423 'by a single LossScaleOptimizer') 424 self._optimizer = inner_optimizer 425 self._optimizer._is_wrapped_by_loss_scale_optimizer = True 426 427 # We don't call super().__init__, since we do not want to call OptimizerV2's 428 # constructor. 
429 base_delegate.DelegatingTrackableMixin.__init__(self, self._optimizer) 430 431 if dynamic: 432 if initial_scale is None: 433 initial_scale = _DEFAULT_INITIAL_SCALE 434 if dynamic_growth_steps is None: 435 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS 436 self._loss_scale = _DynamicLossScaleState( 437 initial_scale, dynamic_growth_steps, multiplier=2) 438 self._track_trackable(self._loss_scale, 'loss_scale') 439 else: 440 if initial_scale is None: 441 raise ValueError('"initial_scale" must be specified if "dynamic" is ' 442 'False') 443 self._loss_scale = float(initial_scale) 444 if dynamic_growth_steps is not None: 445 raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' 446 'is False, but got: %s' % (dynamic_growth_steps,)) 447 448 # To support restoring TensorFlow 2.2 checkpoints. 449 self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 450 'base_optimizer') 451 452 @property 453 def dynamic(self): 454 """Bool indicating whether dynamic loss scaling is used.""" 455 return isinstance(self._loss_scale, _DynamicLossScaleState) 456 457 @property 458 def loss_scale(self): 459 """The current loss scale as a float32 scalar tensor.""" 460 if isinstance(self._loss_scale, _DynamicLossScaleState): 461 return ops.convert_to_tensor_v2_with_dispatch( 462 self._loss_scale.current_loss_scale) 463 else: 464 return ops.convert_to_tensor_v2_with_dispatch(self._loss_scale) 465 466 @property 467 def dynamic_counter(self): 468 """The number of steps since the loss scale was last increased or decreased. 469 470 This is None if `LossScaleOptimizer.dynamic` is False. 471 472 The counter is incremented every step. Once it reaches 473 `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled 474 and the counter will be reset back to zero. If nonfinite gradients are 475 encountered, the loss scale will be halved and the counter will be reset 476 back to zero. 
477 """ 478 if isinstance(self._loss_scale, _DynamicLossScaleState): 479 return self._loss_scale.counter 480 else: 481 return None 482 483 @property 484 def initial_scale(self): 485 """The initial loss scale. 486 487 If `LossScaleOptimizer.dynamic` is False, this is the same number as 488 `LossScaleOptimizer.loss_scale`, as the loss scale never changes. 489 """ 490 if isinstance(self._loss_scale, _DynamicLossScaleState): 491 return self._loss_scale.initial_loss_scale 492 else: 493 return self._loss_scale 494 495 @property 496 def dynamic_growth_steps(self): 497 """The number of steps it takes to increase the loss scale. 498 499 This is None if `LossScaleOptimizer.dynamic` is False. 500 501 Every `dynamic_growth_steps` consecutive steps with finite gradients, the 502 loss scale is increased. 503 """ 504 if isinstance(self._loss_scale, _DynamicLossScaleState): 505 return self._loss_scale.growth_steps 506 else: 507 return None 508 509 @property 510 def inner_optimizer(self): 511 """The optimizer that this LossScaleOptimizer is wrapping.""" 512 return self._optimizer 513 514 def get_scaled_loss(self, loss): 515 """Scales the loss by the loss scale. 516 517 This method is only needed if you compute gradients manually, e.g. with 518 `tf.GradientTape`. In that case, call this method to scale the loss before 519 passing the loss to `tf.GradientTape`. If you use 520 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 521 scaling is automatically applied and this method is unneeded. 522 523 If this method is called, `get_unscaled_gradients` should also be called. 524 See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for 525 an example. 526 527 Args: 528 loss: The loss, which will be multiplied by the loss scale. Can either be 529 a tensor or a callable returning a tensor. 530 531 Returns: 532 `loss` multiplied by `LossScaleOptimizer.loss_scale`. 
533 """ 534 if callable(loss): 535 def new_loss(): 536 loss_val = loss() 537 return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) 538 return new_loss 539 else: 540 return loss * math_ops.cast(self.loss_scale, loss.dtype) 541 542 def get_unscaled_gradients(self, grads): 543 """Unscales the gradients by the loss scale. 544 545 This method is only needed if you compute gradients manually, e.g. with 546 `tf.GradientTape`. In that case, call this method to unscale the gradients 547 after computing them with `tf.GradientTape`. If you use 548 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 549 scaling is automatically applied and this method is unneeded. 550 551 If this method is called, `get_scaled_loss` should also be called. See 552 the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an 553 example. 554 555 Args: 556 grads: A list of tensors, each which will be divided by the loss scale. 557 Can have None values, which are ignored. 558 559 Returns: 560 A new list the same size as `grads`, where every non-None value in `grads` 561 is divided by `LossScaleOptimizer.loss_scale`. 562 """ 563 loss_scale_reciprocal = 1. 
/ self.loss_scale 564 return [ 565 _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None 566 for g in grads 567 ] 568 569 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 570 tape = backprop.GradientTape() if tape is None else tape 571 with tape: 572 loss = self.get_scaled_loss(loss) 573 grads_and_vars = self._optimizer._compute_gradients( # pylint: disable=protected-access 574 loss, 575 var_list, 576 grad_loss, 577 tape=tape) 578 grads = [g for g, _ in grads_and_vars] 579 weights = [v for _, v in grads_and_vars] 580 unscaled_grads = self.get_unscaled_gradients(grads) 581 return list(zip(unscaled_grads, weights)) 582 583 def get_gradients(self, loss, params): 584 loss = self.get_scaled_loss(loss) 585 grads = self._optimizer.get_gradients(loss, params) 586 return self.get_unscaled_gradients(grads) 587 588 def _create_all_weights(self, var_list): 589 self._optimizer._create_all_weights(var_list) # pylint: disable=protected-access 590 591 def apply_gradients(self, 592 grads_and_vars, 593 name=None, 594 experimental_aggregate_gradients=True): 595 if distribution_strategy_context.in_cross_replica_context(): 596 raise ValueError('apply_gradients() must be called in a replica context.') 597 # We check for the strategy here despite already checking in the constructor 598 # as frequently the optimizer is created outside the strategy's scope. 599 self._raise_if_strategy_unsupported() 600 601 grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars) 602 if experimental_aggregate_gradients: 603 # We must aggregate the gradients here instead of in 604 # self.optimizer.apply_gradients, so that any NaN or Inf gradients are 605 # propogated to each replica. If any replica has a NaN or Inf gradient, 606 # they must all have a NaN or Inf gradient so that they all skip the step. 
607 # pylint: disable=protected-access 608 grads_and_vars = self._optimizer._transform_unaggregated_gradients( 609 grads_and_vars) 610 grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars) 611 # pylint: enable=protected-access 612 613 grads_and_vars = tuple(grads_and_vars) 614 grads = [g for g, _ in grads_and_vars] 615 # We do not want DistributionStrategy to unwrap any MirroredVariables in 616 # grads_and_vars, because even in a replica context, the wrapped 617 # optimizer expects mirrored variables. So we wrap the variables with an 618 # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the 619 # MirroredVariables. 620 wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars]) 621 622 def do_not_apply_fn(): 623 # Normally self._optimizer.iterations is incremented in 624 # self._optimizer.apply_gradients(). Since that is not called in this 625 # branch, we increment it here instead. 626 return self._optimizer.iterations.assign_add(1, read_value=False) 627 628 def _if_should_apply_grads(grads): 629 if isinstance(self._loss_scale, _DynamicLossScaleState): 630 return self._loss_scale.update(grads) 631 else: 632 return (control_flow_ops.no_op(), True) 633 634 if optimizer_utils.strategy_supports_no_merge_call(): 635 loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads) 636 def apply_fn(): 637 return self._apply_gradients(grads, wrapped_vars, name) 638 639 maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, 640 do_not_apply_fn) 641 return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) 642 643 else: 644 645 def _apply_gradients_cross_replica(distribution, grads, wrapped_vars, 646 name): 647 loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads) 648 649 def apply_fn(): 650 return distribution.extended.call_for_each_replica( 651 self._apply_gradients, 652 args=(grads, wrapped_vars, name)) 653 654 # Note: We must call this cond() in a cross-replica context. 
655 # DistributionStrategy does not support having a cond in a replica 656 # context with a branch that calls `merge_call`, and 657 # self._optimizer.apply_gradients calls `merge_call`. 658 maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, 659 do_not_apply_fn) 660 return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) 661 return distribution_strategy_context.get_replica_context().merge_call( 662 _apply_gradients_cross_replica, 663 args=(grads, wrapped_vars, name)) 664 665 def _apply_gradients(self, grads, wrapped_vars, name): 666 # Pass experimental_aggregate_gradients=False since LossScaleOptimizer 667 # already aggregated the gradients. 668 # TODO(reedwm): This will raise a fairly cryptic error message if 669 # self._optimizer.apply_gradients does not take 670 # experimental_aggregate_gradients. 671 return self._optimizer.apply_gradients( 672 list(zip(grads, wrapped_vars.value)), name, 673 experimental_aggregate_gradients=False) 674 675 def get_config(self): 676 serialized_optimizer = optimizers.serialize(self._optimizer) 677 return { 678 'inner_optimizer': serialized_optimizer, 679 'dynamic': self.dynamic, 680 'initial_scale': self.initial_scale, 681 'dynamic_growth_steps': self.dynamic_growth_steps, 682 } 683 684 @classmethod 685 def from_config(cls, config, custom_objects=None): 686 config = config.copy() # Make a copy, since we mutate config 687 if 'loss_scale' in config: 688 # If loss_scale is in config, we assume we are deserializing a 689 # LossScaleOptimizer from TF 2.3 or below. We convert the config so it 690 # can be deserialized in the current LossScaleOptimizer. 
691 loss_scale = keras_loss_scale_module.deserialize( 692 config.pop('loss_scale')) 693 if isinstance(loss_scale, loss_scale_module.FixedLossScale): 694 config['dynamic'] = False 695 config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access 696 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 697 config['dynamic'] = True 698 config['initial_scale'] = loss_scale.initial_loss_scale 699 config['dynamic_growth_steps'] = loss_scale.increment_period 700 if loss_scale.multiplier != 2: 701 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 702 'DynamicLossScale whose multiplier is not 2. Got ' 703 'DynamicLossScale: %s' % (loss_scale,)) 704 else: 705 raise ValueError( 706 'Serialized LossScaleOptimizers with a LossScale that is neither a ' 707 'FixedLossScale nor a DynamicLossScale can no longer be ' 708 'deserialized') 709 config['inner_optimizer'] = config.pop('optimizer') 710 config['inner_optimizer'] = optimizers.deserialize( 711 config['inner_optimizer'], custom_objects=custom_objects) 712 return cls(**config) 713 714 def _raise_if_strategy_unsupported(self): 715 if not strategy_supports_loss_scaling(): 716 strategy = distribution_strategy_context.get_strategy() 717 if isinstance(strategy, 718 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, 719 tpu_strategy.TPUStrategyV2)): 720 raise ValueError( 721 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 722 'unnecessary with TPUs, since they support bfloat16 instead of ' 723 'float16 and bfloat16 does not require loss scaling. You should ' 724 'remove the use of the LossScaleOptimizer when TPUs are used.') 725 else: 726 raise ValueError('Loss scaling is not supported with the ' 727 'tf.distribute.Strategy: %s. Try using a different ' 728 'Strategy, e.g. a MirroredStrategy' % 729 strategy.__class__.__name__) 730 731 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer 732 # below. 

  @property
  def iterations(self):
    """The `iterations` attribute of the wrapped optimizer."""
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  def get_slot_names(self):
    """Returns the wrapped optimizer's `get_slot_names()`."""
    return self._optimizer.get_slot_names()

  def variables(self):
    """Returns the wrapped optimizer's `variables()`."""
    return self._optimizer.variables()

  @property
  def weights(self):
    """The `weights` attribute of the wrapped optimizer."""
    return self._optimizer.weights

  def get_weights(self):
    """Returns the wrapped optimizer's `get_weights()`."""
    return self._optimizer.get_weights()

  def set_weights(self, weights):
    """Forwards `weights` to the wrapped optimizer's `set_weights()`."""
    return self._optimizer.set_weights(weights)

  @property
  def clipnorm(self):
    """The `clipnorm` attribute of the wrapped optimizer."""
    return self._optimizer.clipnorm

  @clipnorm.setter
  def clipnorm(self, val):
    self._optimizer.clipnorm = val

  @property
  def global_clipnorm(self):
    """The `global_clipnorm` attribute of the wrapped optimizer."""
    return self._optimizer.global_clipnorm

  @global_clipnorm.setter
  def global_clipnorm(self, val):
    self._optimizer.global_clipnorm = val

  @property
  def clipvalue(self):
    """The `clipvalue` attribute of the wrapped optimizer."""
    return self._optimizer.clipvalue

  @clipvalue.setter
  def clipvalue(self, val):
    self._optimizer.clipvalue = val

  def _aggregate_gradients(self, grads_and_vars):
    """Forwards to the wrapped optimizer's `_aggregate_gradients()`."""
    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Forwards to the wrapped optimizer's `_restore_slot_variable()`."""
    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
                                                  slot_variable)

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Forwards to the wrapped `_create_or_restore_slot_variable()`."""
    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)

  def get_slot(self, var, slot_name):
    """Returns the wrapped optimizer's `get_slot(var, slot_name)`."""
    return self._optimizer.get_slot(var, slot_name)

  def add_slot(self, var, slot_name, initializer='zeros'):
    """Forwards to the wrapped optimizer's `add_slot()`."""
    return self._optimizer.add_slot(var, slot_name, initializer)

  def __getattribute__(self, name):
    """Looks up `name`, falling back to the wrapped optimizer's hyperparameters.

    The normal lookup is tried first. Only if it raises `AttributeError` is the
    wrapped optimizer consulted, with `lr` treated as an alias of
    `learning_rate`.

    Args:
      name: The attribute name being accessed.

    Returns:
      The attribute found by the normal lookup, or the value returned by the
      wrapped optimizer's `_get_hyper(name)` when `name` is one of its
      hyperparameters.

    Raises:
      AttributeError: The error from the normal lookup, if `name` is
        `_optimizer` or `_hyper`, or is not a hyperparameter of the wrapped
        optimizer.
    """
    try:
      return object.__getattribute__(self, name)
    except AttributeError as e:
      if name == '_optimizer' or name == '_hyper':
        # Avoid infinite recursion: the fallback below reads `self._optimizer`,
        # which re-enters this method.
        raise e

      # Delegate hyperparameter accesses to inner optimizer.
      if name == 'lr':
        name = 'learning_rate'
      if name in self._optimizer._hyper:
        return self._optimizer._get_hyper(name)
      raise e

  def __dir__(self):
    """Lists attribute names, including delegated hyperparameter names.

    Returns:
      A list containing the names from the superclass `__dir__()`. If
      `_optimizer` is among them, the wrapped optimizer's hyperparameter names
      are added, plus `lr` when `learning_rate` is a hyperparameter.
    """
    result = set(super(LossScaleOptimizer, self).__dir__())
    if '_optimizer' in result:
      result |= self._optimizer._hyper.keys()
      if 'learning_rate' in self._optimizer._hyper.keys():
        result.add('lr')
    return list(result)

  def __setattr__(self, name, value):
    """Sets `name`, delegating hyperparameters to the wrapped optimizer.

    `lr` is treated as an alias of `learning_rate`. The value is passed to the
    wrapped optimizer's `_set_hyper()` only when `name` is one of its
    hyperparameters and is not already an attribute found by the normal lookup
    on this object; otherwise the superclass `__setattr__` is used.

    Args:
      name: The attribute name being assigned.
      value: The value to assign.
    """
    if name == 'lr':
      name = 'learning_rate'
    # Delegate setting hyperparameter to inner optimizer if the attribute does
    # not exist on the LossScaleOptimizer
    try:
      # We cannot check for the 'iterations' attribute as it cannot be set after
      # it is accessed.
      if name != 'iterations':
        object.__getattribute__(self, name)
      has_attribute = True
    except AttributeError:
      has_attribute = False
    # The `name != '_optimizer'` test comes first so that assigning
    # `_optimizer` itself never evaluates `self._optimizer`.
    if (name != '_optimizer' and name in self._optimizer._hyper
        and not has_attribute):
      self._optimizer._set_hyper(name, value)
    else:
      super(LossScaleOptimizer, self).__setattr__(name, value)

  # Explicitly delegate learning_rate. Normally hyperparameters are delegated in
  # __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
  # (e.g. because self._optimizer itself wraps another optimizer), then it won't
  # be delegated. Since learning_rate is a very commonly accessed
  # hyperparameter, we delegate it here.
847 @property 848 def learning_rate(self): 849 return self._optimizer.learning_rate 850 851 @learning_rate.setter 852 def learning_rate(self, value): 853 self._optimizer.learning_rate = value 854 855 @property 856 def lr(self): 857 return self._optimizer.learning_rate 858 859 @lr.setter 860 def lr(self, value): 861 self._optimizer.lr = value 862 863 # We do not override some OptimizerV2 methods. For each, we describe why we do 864 # not delegate them to self._optimizer: 865 # * get_updates: get_updates() calls get_gradients(). Since we override 866 # get_gradients(), we cannot delegate get_updates() to self._optimizer, 867 # otherwise the overridden get_gradients() method would not be called. 868 # Luckily, get_updates() does not access any OptimizerV2 fields, so 869 # inheriting the OptimizerV2 version works fine. 870 # * minimize: We don't delegate for a similar as get_updates(): it calls 871 # both self._compute_gradients() and self.apply_gradients(), and both need 872 # to have the LossScaleOptimizer version called. 873 874 # TODO(reedwm): Maybe throw an error if mixed precision is used without this 875 # optimizer being used. 876 877 878@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') 879class LossScaleOptimizerV1(LossScaleOptimizer): 880 """An deprecated optimizer that applies loss scaling. 881 882 Warning: This class is deprecated and will be removed in a future version of 883 TensorFlow. Please use the non-experimental class 884 `tf.keras.mixed_precision.LossScaleOptimizer` instead. 885 886 This class is identical to the non-experimental 887 `keras.mixed_precision.LossScaleOptimizer` except its constructor takes 888 different arguments. For this class (the experimental version), the 889 constructor takes a `loss_scale` argument. For the non-experimental class, 890 the constructor encodes the loss scaling information in multiple arguments. 
891 Note that unlike this class, the non-experimental class does not accept a 892 `tf.compat.v1.mixed_precision.LossScale`, which is deprecated. 893 894 If you currently use this class, you should switch to the non-experimental 895 `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several 896 examples of converting the use of the experimental class to the equivalent 897 non-experimental class. 898 899 >>> # In all of the examples below, `opt1` and `opt2` are identical 900 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 901 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 902 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 903 ... tf.keras.optimizers.SGD()) 904 >>> assert opt1.get_config() == opt2.get_config() 905 906 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 907 ... tf.keras.optimizers.SGD(), loss_scale=123) 908 >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123 909 >>> # refers to the initial loss scale, which is the single fixed loss scale 910 >>> # when dynamic=False. 911 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 912 ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123) 913 >>> assert opt1.get_config() == opt2.get_config() 914 915 >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale( 916 ... initial_loss_scale=2048, increment_period=500) 917 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 918 ... tf.keras.optimizers.SGD(), loss_scale=loss_scale) 919 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 920 ... tf.keras.optimizers.SGD(), initial_scale=2048, 921 ... dynamic_growth_steps=500) 922 >>> assert opt1.get_config() == opt2.get_config() 923 924 Make sure to also switch from this class to the non-experimental class in 925 isinstance checks, if you have any. 
If you do not do this, your model may run 926 into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses 927 the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to 928 switch isinstance checks to the non-experimental `LossScaleOptimizer` even 929 before using the non-experimental `LossScaleOptimizer`. 930 931 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 932 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 933 >>> # The experimental class subclasses the non-experimental class 934 >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer) 935 True 936 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 937 ... tf.keras.optimizers.SGD()) 938 >>> # The non-experimental class does NOT subclass the experimental class. 939 >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer) 940 False 941 942 Args: 943 optimizer: The Optimizer instance to wrap. 944 loss_scale: The loss scale to scale the loss and gradients. This can 945 either be an int/float to use a fixed loss scale, the string "dynamic" 946 to use dynamic loss scaling, or an instance of a LossScale. The string 947 "dynamic" equivalent to passing `DynamicLossScale()`, and passing an 948 int/float is equivalent to passing a FixedLossScale with the given loss 949 scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must 950 be 2 (the default). 951 """ 952 953 def __init__(self, optimizer, loss_scale): 954 warn_msg_prefix = ( 955 'tf.keras.mixed_precision.experimental.LossScaleOptimizer is ' 956 'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer ' 957 'instead. 
') 958 959 if isinstance(loss_scale, dict): 960 loss_scale = keras_loss_scale_module.deserialize(loss_scale) 961 962 if isinstance(loss_scale, (int, float)): 963 tf_logging.warning( 964 warn_msg_prefix + 'For example:\n' 965 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 966 'opt, dynamic=False, initial_scale={})'.format(loss_scale)) 967 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 968 initial_scale=loss_scale) 969 elif isinstance(loss_scale, loss_scale_module.FixedLossScale): 970 ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access 971 tf_logging.warning( 972 warn_msg_prefix + 'For example:\n' 973 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 974 'opt, dynamic=False, initial_scale={})'.format(ls_val)) 975 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 976 initial_scale=ls_val) 977 elif loss_scale == 'dynamic': 978 tf_logging.warning( 979 warn_msg_prefix + 'For example:\n' 980 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 981 'opt)') 982 super(LossScaleOptimizerV1, self).__init__(optimizer) 983 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 984 kwargs = {} 985 extra_arguments = '' 986 if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE: 987 kwargs['initial_scale'] = loss_scale.initial_loss_scale 988 extra_arguments += (', initial_scale=%s' % 989 loss_scale.initial_loss_scale) 990 if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS: 991 kwargs['dynamic_growth_steps'] = loss_scale.increment_period 992 extra_arguments += (', dynamic_growth_steps=%s' % 993 loss_scale.increment_period) 994 if loss_scale.multiplier != 2: 995 raise ValueError('When passing a DynamicLossScale to "loss_scale", ' 996 'DynamicLossScale.multiplier must be 2. 
Got: %s' 997 % (loss_scale,)) 998 tf_logging.warning( 999 warn_msg_prefix + 1000 'Note that the non-experimental LossScaleOptimizer does not take a ' 1001 'DynamicLossScale but instead takes the dynamic configuration ' 1002 'directly in the constructor. For example:\n' 1003 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 1004 'opt{})\n'.format(extra_arguments)) 1005 super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs) 1006 elif isinstance(loss_scale, loss_scale_module.LossScale): 1007 raise TypeError('Passing a LossScale that is not a FixedLossScale or a ' 1008 'DynamicLossScale is no longer supported. Got: {}' 1009 .format(loss_scale)) 1010 else: 1011 raise ValueError('Invalid value passed to loss_scale. loss_scale ' 1012 'must be the string "dynamic" (recommended), an int, ' 1013 'a float, a FixedLossScale, or a DynamicLossScale. Got ' 1014 'value: {}'.format(loss_scale)) 1015 1016 @classmethod 1017 def from_config(cls, config, custom_objects=None): 1018 config = config.copy() # Make a copy, since we mutate config 1019 1020 # If loss_scale is in config, we assume we are deserializing a 1021 # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are 1022 # deserializing a LossScaleOptimizer from TF 2.4 or above. 1023 if 'loss_scale' in config: 1024 config['loss_scale'] = keras_loss_scale_module.deserialize( 1025 config['loss_scale']) 1026 if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale) 1027 and config['loss_scale'].multiplier != 2): 1028 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 1029 'DynamicLossScale whose multiplier is not 2. 
Got ' 1030 'DynamicLossScale: %s' % (config['loss_scale'],)) 1031 config['optimizer'] = optimizers.deserialize( 1032 config['optimizer'], custom_objects=custom_objects) 1033 return cls(**config) 1034 1035 # We convert the config, as generated by LossScaleOptimizer.get_config, to a 1036 # version that can be passed to LossScaleOptimizerV1.__init__ 1037 if config['dynamic']: 1038 config['loss_scale'] = loss_scale_module.DynamicLossScale( 1039 config['initial_scale'], config['dynamic_growth_steps'], multiplier=2) 1040 else: 1041 config['loss_scale'] = loss_scale_module.FixedLossScale( 1042 config['initial_scale']) 1043 1044 del config['dynamic'] 1045 del config['initial_scale'] 1046 del config['dynamic_growth_steps'] 1047 config['optimizer'] = optimizers.deserialize( 1048 config.pop('inner_optimizer'), custom_objects=custom_objects) 1049 return cls(**config) 1050 1051 1052class FakeOptimizerForRestoration(trackable.Trackable): 1053 """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints. 1054 1055 The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class 1056 exists to support restoring TF 2.2 checkpoints in newer version of TensorFlow. 1057 1058 In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the 1059 following in LossScaleOptimizer.__init__ 1060 1061 ``` 1062 self._track_trackable(self._optimizer, 'base_optimizer') 1063 ``` 1064 1065 This means a dependency from the LossScaleOptimizer to the wrapped optimizer 1066 would be stored in the checkpoint. However now, the checkpoint format with a 1067 LossScaleOptimizer is the same as the format without a LossScaleOptimizer, 1068 except the loss scale is also stored. This means there is no dependency from 1069 the LossScaleOptimizer to the wrapped optimizer. Instead, the 1070 LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's 1071 perspective, by overriding all Trackable methods and delegating them to the 1072 wrapped optimizer. 
1073 1074 To allow restoring TF 2.2. checkpoints, LossScaleOptimizer adds a dependency 1075 on this class instead of the inner optimizer. When restored, this class will 1076 instead restore the slot variables of the inner optimizer. Since this class 1077 has no variables, it does not affect the checkpoint when saved. 1078 """ 1079 1080 def __init__(self, optimizer): 1081 self._optimizer = optimizer 1082 1083 def get_slot_names(self): 1084 return self._optimizer.get_slot_names() 1085 1086 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, 1087 variable): 1088 return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access 1089 slot_variable_position, slot_name, variable) 1090 1091 1092mixed_precision.register_loss_scale_wrapper(optimizer_v2.OptimizerV2, 1093 LossScaleOptimizerV1) 1094 1095 1096def _multiply_gradient(gradient, scale): 1097 """Multiply a (possibly sparse) gradient by the given scale factor.""" 1098 scale = math_ops.cast(scale, gradient.dtype) 1099 if isinstance(gradient, indexed_slices.IndexedSlices): 1100 return indexed_slices.IndexedSlices( 1101 gradient.values * scale, 1102 gradient.indices, 1103 dense_shape=gradient.dense_shape) 1104 else: 1105 return gradient * scale 1106 1107 1108def strategy_supports_loss_scaling(): 1109 """Returns True if the current Strategy supports loss scaling.""" 1110 if not distribution_strategy_context.has_strategy(): 1111 return True 1112 strategy = distribution_strategy_context.get_strategy() 1113 # Strategies are supported if either there is only one replica or if variables 1114 # are replicated per device. Otherwise, the current model.fit() implementation 1115 # and most custom training loops incorrectly unscale the gradients. Currently, 1116 # gradients are unscaled once per compute replica, but they should be unscaled 1117 # once per variable replica. 
When there is one variable replica for each 1118 # compute replica, this works fine, but otherwise issues will occur. 1119 # TODO(reedwm): Support all strategies. 1120 return isinstance(strategy, ( 1121 collective_all_reduce_strategy.CollectiveAllReduceStrategy, 1122 collective_all_reduce_strategy.CollectiveAllReduceStrategyV1, 1123 one_device_strategy.OneDeviceStrategy, 1124 one_device_strategy.OneDeviceStrategyV1, 1125 mirrored_strategy.MirroredStrategy, 1126 mirrored_strategy.MirroredStrategyV1, 1127 )) 1128