# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

import abc
import contextlib
import functools
import warnings

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


_DEFAULT_VALID_DTYPES = frozenset([
    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
    dtypes.complex64, dtypes.complex128
])


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
73 """ 74 unique_indices, new_index_positions = array_ops.unique(indices) 75 summed_values = math_ops.unsorted_segment_sum( 76 values, new_index_positions, 77 array_ops.shape(unique_indices)[0]) 78 return (summed_values, unique_indices) 79 80 81class NullContextmanager(object): 82 83 def __init__(self, *args, **kwargs): 84 pass 85 86 def __enter__(self): 87 pass 88 89 def __exit__(self, type_arg, value_arg, traceback_arg): 90 return False # False values do not suppress exceptions 91 92 93def name_scope_only_in_function_or_graph(name): 94 """Internal-only entry point for `name_scope*`. 95 96 Enters a compat.v1.name_scope only when in a function or graph, 97 not when running fully eagerly. 98 99 Args: 100 name: The name argument that is passed to the op function. 101 102 Returns: 103 `name_scope*` context manager. 104 """ 105 if not context.executing_eagerly(): 106 return ops.name_scope_v1(name) 107 else: 108 return NullContextmanager() 109 110 111@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta) 112class OptimizerV2(trackable.Trackable): 113 """Base class for Keras optimizers. 114 115 You should not use this class directly, but instead instantiate one of its 116 subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc. 117 118 ### Usage 119 120 ```python 121 # Create an optimizer with the desired parameters. 122 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 123 # `loss` is a callable that takes no argument and returns the value 124 # to minimize. 125 loss = lambda: 3 * var1 * var1 + 2 * var2 * var2 126 # In graph mode, returns op that minimizes the loss by updating the listed 127 # variables. 128 opt_op = opt.minimize(loss, var_list=[var1, var2]) 129 opt_op.run() 130 # In eager mode, simply call minimize to update the list of variables. 131 opt.minimize(loss, var_list=[var1, var2]) 132 ``` 133 134 ### Usage in custom training loops 135 136 In Keras models, sometimes variables are created when the model is first 137 called, instead of construction time. Examples include 1) sequential models 138 without input shape pre-defined, or 2) subclassed models. Pass var_list as 139 callable in these cases. 140 141 Example: 142 143 ```python 144 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 145 model = tf.keras.Sequential() 146 model.add(tf.keras.layers.Dense(num_hidden, activation='relu')) 147 model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid')) 148 loss_fn = lambda: tf.keras.losses.mse(model(input), output) 149 var_list_fn = lambda: model.trainable_weights 150 for input, output in data: 151 opt.minimize(loss_fn, var_list_fn) 152 ``` 153 154 ### Processing gradients before applying them 155 156 Calling `minimize()` takes care of both computing the gradients and 157 applying them to the variables. If you want to process the gradients 158 before applying them you can instead use the optimizer in three steps: 159 160 1. Compute the gradients with `tf.GradientTape`. 161 2. Process the gradients as you wish. 162 3. Apply the processed gradients with `apply_gradients()`. 163 164 Example: 165 166 ```python 167 # Create an optimizer. 168 opt = tf.keras.optimizers.SGD(learning_rate=0.1) 169 170 # Compute the gradients for a list of variables. 171 with tf.GradientTape() as tape: 172 loss = <call_loss_function> 173 vars = <list_of_variables> 174 grads = tape.gradient(loss, vars) 175 176 # Process the gradients, for example cap them, etc. 
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, vars))
  ```

  ### Use with `tf.distribute.Strategy`

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  To aggregate gradients yourself, call `apply_gradients` with
  `experimental_aggregate_gradients` set to False. This is useful if you need
  to process the aggregated gradients yourself.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.

  ### Variable Constraints

  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The
  user needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train. These are
  called <i>Slots</i>. Slots have names and you can ask the optimizer for the
  names of the slots that it uses. Once you have a slot name you can ask the
  optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log debug information about a training
  algorithm, report stats about the slots, etc.

  ### Hyperparameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyperparameter.

  Hyperparameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # Update the learning rate.
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Callable learning rate

  Optimizer accepts a callable learning rate in two ways. The first way is
  through built-in or customized
  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
  called on each iteration with `schedule(iteration)`, a `tf.Variable`
  owned by the optimizer.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
  ...   initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  The second way is through a callable function that
  does not accept any arguments.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> def lr_callable():
  ...   return .1
  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  ### Creating a custom optimizer

  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods:

    - `_resource_apply_dense` (update a variable given its gradient tensor is
      a dense `tf.Tensor`)
    - `_resource_apply_sparse` (update a variable given its gradient tensor is
      a sparse `tf.IndexedSlices`. The most common way for this to happen
      is if you are taking the gradient through a `tf.gather`.)
    - `_create_slots`
      (if your optimizer algorithm requires additional variables)
    - `get_config`
      (serialization of the optimizer, include all hyperparameters)

  (A minimal illustrative sketch of such a subclass appears at the end of this
  module.)
  """

  # Subclasses should set this to True unless they override `apply_gradients`
  # with a version that does not have the `experimental_aggregate_gradients`
  # argument. Older versions of Keras did not have this argument so custom
  # optimizers may have overridden `apply_gradients` without the
  # `experimental_aggregate_gradients` argument. Keras only passes
  # `experimental_aggregate_gradients` if this attribute is True.
  # Note: This attribute will likely be removed in an upcoming release.
  _HAS_AGGREGATE_GRAD = False

  def __init__(self,
               name,
               gradient_aggregator=None,
               gradient_transformers=None,
               **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/state.get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Example of custom gradient transformations:

    ```python
    def my_gradient_transformer(grads_and_vars):
      # Simple example, double the gradients.
      return [(2. * g, v) for g, v in grads_and_vars]

    optimizer = tf.keras.optimizers.SGD(
        1e-3, gradient_transformers=[my_gradient_transformer])
    ```

    Args:
      name: String. The name to use for momentum accumulator weights created
        by the optimizer.
      gradient_aggregator: The function to use to aggregate gradients across
        devices (when using `tf.distribute.Strategy`).
        If `None`, defaults to summing the gradients across devices. The
        function should accept and return a list of `(gradient, variable)`
        tuples.
      gradient_transformers: Optional. List of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after `gradient_aggregator`. The functions should accept and
        return a list of `(gradient, variable)` tuples.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradients of all weights are
        clipped so that their global norm is no higher than this value.

    Raises:
      ValueError: in case of any invalid argument.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # Checks that all keyword arguments are non-negative.
      if kwargs[k] is not None and kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
      if k == "lr":
        warnings.warn(
            "The `lr` argument is deprecated, use `learning_rate` instead.")

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    # {slot_name :
    #     {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #  ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay

    self._hypers_created = False
    # Store the distribution strategy object if the optimizer is created inside
    # strategy scope, so it could be used to create variables later.
    if distribute_ctx.has_strategy():
      self._distribution_strategy = distribute_ctx.get_strategy()
    else:
      self._distribution_strategy = None

    # Configure gradient transformations.
    if gradient_aggregator is None:
      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
    self.gradient_aggregator = gradient_aggregator
    if gradient_transformers is None:
      gradient_transformers = []
    self.gradient_transformers = gradient_transformers
    self.clipnorm = kwargs.pop("clipnorm", None)
    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
    if self.clipnorm is not None and self.global_clipnorm is not None:
      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
                           self.clipnorm, self.global_clipnorm))
    self.clipvalue = kwargs.pop("clipvalue", None)

  @property
  def clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._clipnorm

  @property
  def global_clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._global_clipnorm

  @clipnorm.setter
  def clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipnorm = val
    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
        self._clipnorm)

  @global_clipnorm.setter
  def global_clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`global_clipnorm` cannot be set when "
                       "`gradient_transformers` is set. Instead, use the "
                       "`gradient_transformers` to specify clipping and other "
                       "transformations.")
    self._global_clipnorm = val
    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
        self._global_clipnorm)

  @property
  def clipvalue(self):
    """`float` or `None`. If set, clips gradients to a maximum value."""
    return self._clipvalue

  @clipvalue.setter
  def clipvalue(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipvalue = val
    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
        self._clipvalue)

  def _transform_loss(self, loss):
    """Called in `.minimize` to transform loss before computing gradients."""
    return loss

  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
    """Called in `minimize` to compute gradients from loss."""
    grads = tape.gradient(loss, var_list, grad_loss)
    return list(zip(grads, var_list))

  def _transform_unaggregated_gradients(self, grads_and_vars):
    """Called in `apply_gradients` before gradient aggregation."""
    return grads_and_vars

  def _aggregate_gradients(self, grads_and_vars):
    """Called in `apply_gradients` to aggregate gradients across devices.

    Note that user subclasses may override this, so the interface should not be
    changed.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.

    Returns:
      A list of (aggregated_gradient, variable) pairs. By default, this calls
      `self.gradient_aggregator`.
    """
    return self.gradient_aggregator(grads_and_vars)

  def _transform_gradients(self, grads_and_vars):
    """Called in `apply_gradients` after aggregation."""
    if self._clipvalue is not None:
      grads_and_vars = self._clipvalue_fn(grads_and_vars)
    if self._clipnorm is not None:
      grads_and_vars = self._clipnorm_fn(grads_and_vars)
    if self._global_clipnorm is not None:
      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)

    for fn in self.gradient_transformers:
      grads_and_vars = fn(grads_and_vars)
    return grads_and_vars

  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no
        arguments and return the value to minimize.
        If a `Tensor`, the `tape` argument must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable`
        objects. Use a callable when the variable list would otherwise be
        incomplete before `minimize`, since the variables are created the
        first time `loss` is called.
      grad_loss: (Optional). A `Tensor` holding the gradient computed for
        `loss`.
      name: (Optional) str. Name for the returned operation.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no
        arguments and return the value to minimize. If a `Tensor`, the `tape`
        argument must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable`
        objects. Use a callable when the variable list would otherwise be
        incomplete before `minimize` and the variables are created the first
        time `loss` is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable`
        objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if not callable(loss) and tape is None:
      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
    tape = tape if tape is not None else backprop.GradientTape()

    if callable(loss):
      with tape:
        if not callable(var_list):
          tape.watch(var_list)
        loss = loss()
        if callable(var_list):
          var_list = var_list()

    with tape:
      loss = self._transform_loss(loss)

    var_list = nest.flatten(var_list)
    with ops.name_scope_v2(self._name + "/gradients"):
      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)

    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself
    by passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it's
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in a cross-replica context.
    """
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with ops.name_scope_v2(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        self._create_all_weights(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients.
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients()` cannot be called in cross-replica context. "
" 651 "Use `tf.distribute.Strategy.run` to enter replica " 652 "context.") 653 654 strategy = distribute_ctx.get_strategy() 655 if (not experimental_aggregate_gradients and strategy and 656 isinstance(strategy, 657 (parameter_server_strategy.ParameterServerStrategyV1, 658 parameter_server_strategy_v2.ParameterServerStrategyV2, 659 central_storage_strategy.CentralStorageStrategy, 660 central_storage_strategy.CentralStorageStrategyV1))): 661 raise NotImplementedError( 662 "`experimental_aggregate_gradients=False is not supported for " 663 "ParameterServerStrategy and CentralStorageStrategy") 664 665 apply_state = self._prepare(var_list) 666 if experimental_aggregate_gradients: 667 grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars) 668 grads_and_vars = self._aggregate_gradients(grads_and_vars) 669 grads_and_vars = self._transform_gradients(grads_and_vars) 670 671 if optimizer_utils.strategy_supports_no_merge_call(): 672 return self._distributed_apply(strategy, grads_and_vars, name, 673 apply_state) 674 else: 675 return distribute_ctx.get_replica_context().merge_call( 676 functools.partial(self._distributed_apply, apply_state=apply_state), 677 args=(grads_and_vars,), 678 kwargs={ 679 "name": name, 680 }) 681 682 def _distributed_apply(self, distribution, grads_and_vars, name, apply_state): 683 """`apply_gradients` using a `DistributionStrategy`.""" 684 685 def apply_grad_to_update_var(var, grad): 686 """Apply gradient to variable.""" 687 if isinstance(var, ops.Tensor): 688 raise NotImplementedError("Trying to update a Tensor ", var) 689 690 apply_kwargs = {} 691 if isinstance(grad, indexed_slices.IndexedSlices): 692 if var.constraint is not None: 693 raise RuntimeError( 694 "Cannot use a constraint function on a sparse variable.") 695 if "apply_state" in self._sparse_apply_args: 696 apply_kwargs["apply_state"] = apply_state 697 return self._resource_apply_sparse_duplicate_indices( 698 grad.values, var, grad.indices, **apply_kwargs) 699 700 if "apply_state" in self._dense_apply_args: 701 apply_kwargs["apply_state"] = apply_state 702 update_op = self._resource_apply_dense(grad, var, **apply_kwargs) 703 if var.constraint is not None: 704 with ops.control_dependencies([update_op]): 705 return var.assign(var.constraint(var)) 706 else: 707 return update_op 708 709 eagerly_outside_functions = ops.executing_eagerly_outside_functions() 710 update_ops = [] 711 with name_scope_only_in_function_or_graph(name or self._name): 712 for grad, var in grads_and_vars: 713 # Colocate the update with variables to avoid unnecessary communication 714 # delays. See b/136304694. 715 with distribution.extended.colocate_vars_with(var): 716 with name_scope_only_in_function_or_graph( 717 "update" if eagerly_outside_functions else "update_" + 718 var.op.name): 719 update_op = distribution.extended.update( 720 var, apply_grad_to_update_var, args=(grad,), group=False) 721 if distribute_ctx.in_cross_replica_context(): 722 # In cross-replica context, extended.update returns a list of 723 # update ops from all replicas (group=False). 724 update_ops.extend(update_op) 725 else: 726 # In replica context, extended.update return the single update op 727 # of current replica. 
              update_ops.append(update_op)

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
            return self._iterations.assign_add(1, read_value=False)

      return self._iterations.assign_add(1)

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Should be used only in legacy v1 graph mode.

    Args:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
    return grads

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """Set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def _create_slots(self, var_list):
    pass

  def _create_all_weights(self, var_list):
    """Creates all weights, including iterations, hyperparameters and slot vars.

    This will add newly created variables to `optimizer.weights`.

    New variables are only created when this method is called the first time,
    or when called with different variables in the var_list.

    Args:
      var_list: list or tuple of `Variable` objects that will be minimized
        using this optimizer.
824 """ 825 826 _ = self.iterations 827 self._create_hypers() 828 self._create_slots(var_list) 829 830 def __getattribute__(self, name): 831 """Overridden to support hyperparameter access.""" 832 try: 833 return super(OptimizerV2, self).__getattribute__(name) 834 except AttributeError as e: 835 # Needed to avoid infinite recursion with __setattr__. 836 if name == "_hyper": 837 raise e 838 # Backwards compatibility with Keras optimizers. 839 if name == "lr": 840 name = "learning_rate" 841 if name in self._hyper: 842 return self._get_hyper(name) 843 raise e 844 845 def __dir__(self): 846 result = set(super(OptimizerV2, self).__dir__()) 847 if "_hyper" in result: 848 result |= self._hyper.keys() 849 if "learning_rate" in self._hyper.keys(): 850 result.add("lr") 851 return list(result) 852 853 def __setattr__(self, name, value): 854 """Override setattr to support dynamic hyperparameter setting.""" 855 # Backwards compatibility with Keras optimizers. 856 if name == "lr": 857 name = "learning_rate" 858 if hasattr(self, "_hyper") and name in self._hyper: 859 self._set_hyper(name, value) 860 else: 861 super(OptimizerV2, self).__setattr__(name, value) 862 863 def get_slot_names(self): 864 """A list of names for this optimizer's slots.""" 865 return self._slot_names 866 867 def add_slot(self, var, slot_name, initializer="zeros", shape=None): 868 """Add a new slot variable for `var`. 869 870 A slot variable is an additional variable associated with `var` to train. 871 It is allocated and managed by optimizers, e.g. `Adam`. 872 873 Args: 874 var: a `Variable` object. 875 slot_name: name of the slot variable. 876 initializer: initializer of the slot variable 877 shape: (Optional) shape of the slot variable. If not set, it will default 878 to the shape of `var`. 879 880 Returns: 881 A slot variable. 882 """ 883 if slot_name not in self._slot_names: 884 self._slot_names.append(slot_name) 885 var_key = _var_key(var) 886 slot_dict = self._slots.setdefault(var_key, {}) 887 weight = slot_dict.get(slot_name, None) 888 if weight is None: 889 if isinstance(initializer, str) or callable(initializer): 890 initializer = initializers.get(initializer) 891 if isinstance( 892 initializer, 893 trackable.CheckpointInitialValueCallable) or (shape is not None): 894 slot_shape = shape 895 else: 896 slot_shape = var.shape 897 initial_value = functools.partial( 898 initializer, shape=slot_shape, dtype=var.dtype) 899 else: 900 initial_value = initializer 901 902 with self._distribution_strategy_scope(): 903 strategy = distribute_ctx.get_strategy() 904 if not strategy.extended.variable_created_in_scope(var): 905 raise ValueError( 906 "Trying to create optimizer slot variable under the scope for " 907 "tf.distribute.Strategy ({}), which is different from the scope " 908 "used for the original variable ({}). Make sure the slot " 909 "variables are created under the same strategy scope. 
              "happen if you're restoring from a checkpoint outside the scope."
              .format(strategy, var))

        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      if isinstance(var, ds_values.DistributedValues):
        var_devices = var._devices  # pylint: disable=protected-access
      else:
        var_devices = [var.device]
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    with self._distribution_strategy_scope():
      # Iterate hyper values deterministically.
      for name, value in sorted(self._hyper.items()):
        if isinstance(value,
                      (ops.Tensor, tf_variables.Variable)) or callable(value):
          # The check for `callable` covers the usage when `value` is a
          # `LearningRateSchedule`, in which case it does not need to create a
          # variable.
          continue
        else:
          self._hyper[name] = self.add_weight(
              name,
              shape=[],
              trainable=False,
              initializer=value,
              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      with self._distribution_strategy_scope():
        self._iterations = self.add_weight(
            "iter",
            shape=[],
            dtype=dtypes.int64,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = math_ops.cast(self._initial_decay, var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
      Python dictionary.
    """
    config = {"name": self._name}
    if self.clipnorm is not None:
      config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
      config["clipvalue"] = self.clipvalue
    if self.global_clipnorm is not None:
      config["global_clipnorm"] = self.global_clipnorm
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Args:
      config: A Python dictionary, typically the output of get_config.
      custom_objects: A Python dictionary mapping names to additional Python
        objects used to create this optimizer, such as a function used for a
        hyperparameter.

    Returns:
      An optimizer instance.
1054 """ 1055 if "lr" in config: 1056 config["learning_rate"] = config.pop("lr") 1057 if "learning_rate" in config: 1058 if isinstance(config["learning_rate"], dict): 1059 config["learning_rate"] = learning_rate_schedule.deserialize( 1060 config["learning_rate"], custom_objects=custom_objects) 1061 return cls(**config) 1062 1063 def _serialize_hyperparameter(self, hyperparameter_name): 1064 """Serialize a hyperparameter that can be a float, callable, or Tensor.""" 1065 value = self._hyper[hyperparameter_name] 1066 if isinstance(value, learning_rate_schedule.LearningRateSchedule): 1067 return learning_rate_schedule.serialize(value) 1068 if callable(value): 1069 return value() 1070 if tensor_util.is_tf_type(value): 1071 return backend.get_value(value) 1072 return value 1073 1074 def variables(self): 1075 """Returns variables of this Optimizer based on the order created.""" 1076 return self._weights 1077 1078 @property 1079 def weights(self): 1080 """Returns variables of this Optimizer based on the order created.""" 1081 return self._weights 1082 1083 def get_weights(self): 1084 """Returns the current weights of the optimizer. 1085 1086 The weights of an optimizer are its state (ie, variables). 1087 This function returns the weight values associated with this 1088 optimizer as a list of Numpy arrays. The first value is always the 1089 iterations count of the optimizer, followed by the optimizer's state 1090 variables in the order they were created. The returned list can in turn 1091 be used to load state into similarly parameterized optimizers. 1092 1093 For example, the RMSprop optimizer for this simple model returns a list of 1094 three values-- the iteration count, followed by the root-mean-square value 1095 of the kernel and bias of the single Dense layer: 1096 1097 >>> opt = tf.keras.optimizers.RMSprop() 1098 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 1099 >>> m.compile(opt, loss='mse') 1100 >>> data = np.arange(100).reshape(5, 20) 1101 >>> labels = np.zeros(5) 1102 >>> results = m.fit(data, labels) # Training. 1103 >>> len(opt.get_weights()) 1104 3 1105 1106 Returns: 1107 Weights values as a list of numpy arrays. 1108 """ 1109 params = self.weights 1110 return backend.batch_get_value(params) 1111 1112 # TODO(tanzheny): Maybe share this logic with base_layer. 1113 def set_weights(self, weights): 1114 """Set the weights of the optimizer. 1115 1116 The weights of an optimizer are its state (ie, variables). 1117 This function takes the weight values associated with this 1118 optimizer as a list of Numpy arrays. The first value is always the 1119 iterations count of the optimizer, followed by the optimizer's state 1120 variables in the order they are created. The passed values are used to set 1121 the new state of the optimizer. 1122 1123 For example, the RMSprop optimizer for this simple model takes a list of 1124 three values-- the iteration count, followed by the root-mean-square value 1125 of the kernel and bias of the single Dense layer: 1126 1127 >>> opt = tf.keras.optimizers.RMSprop() 1128 >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 1129 >>> m.compile(opt, loss='mse') 1130 >>> data = np.arange(100).reshape(5, 20) 1131 >>> labels = np.zeros(5) 1132 >>> results = m.fit(data, labels) # Training. 
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Args:
      weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, str) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when the variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
1233 """ 1234 return _DEFAULT_VALID_DTYPES 1235 1236 def _call_if_callable(self, param): 1237 """Call the function if param is callable.""" 1238 return param() if callable(param) else param 1239 1240 def _resource_apply_dense(self, grad, handle, apply_state): 1241 """Add ops to apply dense gradients to the variable `handle`. 1242 1243 Args: 1244 grad: a `Tensor` representing the gradient. 1245 handle: a `Tensor` of dtype `resource` which points to the variable to be 1246 updated. 1247 apply_state: A dict which is used across multiple apply calls. 1248 1249 Returns: 1250 An `Operation` which updates the value of the variable. 1251 """ 1252 raise NotImplementedError("Must be implemented in subclasses.") 1253 1254 def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices, 1255 **kwargs): 1256 """Add ops to apply sparse gradients to `handle`, with repeated indices. 1257 1258 Optimizers which override this method must deal with repeated indices. See 1259 the docstring of `_apply_sparse_duplicate_indices` for details. By default 1260 the correct behavior, to sum non-unique indices and their associated 1261 gradients, is enforced by first pre-processing `grad` and `indices` and 1262 passing them on to `_resource_apply_sparse`. Optimizers which deal correctly 1263 with duplicate indices may instead override this method to avoid the 1264 overhead of summing. 1265 1266 Args: 1267 grad: a `Tensor` representing the gradient for the affected indices. 1268 handle: a `Tensor` of dtype `resource` which points to the variable to be 1269 updated. 1270 indices: a `Tensor` of integral type representing the indices for which 1271 the gradient is nonzero. Indices may be repeated. 1272 **kwargs: May optionally contain `apply_state` 1273 1274 Returns: 1275 An `Operation` which updates the value of the variable. 1276 """ 1277 summed_grad, unique_indices = _deduplicate_indexed_slices( 1278 values=grad, indices=indices) 1279 return self._resource_apply_sparse(summed_grad, handle, unique_indices, 1280 **kwargs) 1281 1282 def _resource_apply_sparse(self, grad, handle, indices, apply_state): 1283 """Add ops to apply sparse gradients to the variable `handle`. 1284 1285 Similar to `_apply_sparse`, the `indices` argument to this method has been 1286 de-duplicated. Optimizers which deal correctly with non-unique indices may 1287 instead override `_resource_apply_sparse_duplicate_indices` to avoid this 1288 overhead. 1289 1290 Args: 1291 grad: a `Tensor` representing the gradient for the affected indices. 1292 handle: a `Tensor` of dtype `resource` which points to the variable to be 1293 updated. 1294 indices: a `Tensor` of integral type representing the indices for which 1295 the gradient is nonzero. Indices are unique. 1296 apply_state: A dict which is used across multiple apply calls. 1297 1298 Returns: 1299 An `Operation` which updates the value of the variable. 
1300 """ 1301 raise NotImplementedError("Must be implemented in subclasses.") 1302 1303 def _resource_scatter_add(self, x, i, v): 1304 with ops.control_dependencies([ 1305 gen_resource_variable_ops.ResourceScatterAdd( 1306 resource=x.handle, indices=i, updates=v) 1307 ]): 1308 return x.value() 1309 1310 def _resource_scatter_update(self, x, i, v): 1311 with ops.control_dependencies( 1312 [gen_resource_variable_ops.ResourceScatterUpdate( 1313 resource=x.handle, indices=i, updates=v)]): 1314 return x.value() 1315 1316 @property 1317 @layer_utils.cached_per_instance 1318 def _dense_apply_args(self): 1319 return tf_inspect.getfullargspec(self._resource_apply_dense).args 1320 1321 @property 1322 @layer_utils.cached_per_instance 1323 def _sparse_apply_args(self): 1324 return tf_inspect.getfullargspec(self._resource_apply_sparse).args 1325 1326 # --------------- 1327 # For implementing the trackable interface 1328 # --------------- 1329 1330 def _restore_slot_variable(self, slot_name, variable, slot_variable): 1331 """Restore a newly created slot variable's value.""" 1332 variable_key = _var_key(variable) 1333 deferred_restorations = self._deferred_slot_restorations.get( 1334 slot_name, {}).pop(variable_key, []) 1335 # Iterate over restores, highest restore UID first to minimize the number 1336 # of assignments. 1337 deferred_restorations.sort(key=lambda position: position.restore_uid, 1338 reverse=True) 1339 for checkpoint_position in deferred_restorations: 1340 checkpoint_position.restore(slot_variable) 1341 1342 def _create_or_restore_slot_variable( 1343 self, slot_variable_position, slot_name, variable): 1344 """Restore a slot variable's value, possibly creating it. 1345 1346 Called when a variable which has an associated slot variable is created or 1347 restored. When executing eagerly, we create the slot variable with a 1348 restoring initializer. 1349 1350 No new variables are created when graph building. Instead, 1351 _restore_slot_variable catches these after normal creation and adds restore 1352 ops to the graph. This method is nonetheless important when graph building 1353 for the case when a slot variable has already been created but `variable` 1354 has just been added to a dependency graph (causing us to realize that the 1355 slot variable needs to be restored). 1356 1357 Args: 1358 slot_variable_position: A `trackable._CheckpointPosition` object 1359 indicating the slot variable `Trackable` object to be restored. 1360 slot_name: The name of this `Optimizer`'s slot to restore into. 1361 variable: The variable object this slot is being created for. 1362 """ 1363 variable_key = _var_key(variable) 1364 slot_dict = self._slots.get(variable_key, {}) 1365 slot_variable = slot_dict.get(slot_name, None) 1366 if (slot_variable is None and context.executing_eagerly() and 1367 slot_variable_position.is_simple_variable() 1368 # Defer slot variable creation if there is an active variable creator 1369 # scope. Generally we'd like to eagerly create/restore slot variables 1370 # when possible, but this may mean that scopes intended to catch 1371 # `variable` also catch its eagerly created slot variable 1372 # unintentionally (specifically make_template would add a dependency on 1373 # a slot variable if not for this case). Deferring is mostly harmless 1374 # (aside from double initialization), and makes variable creator scopes 1375 # behave the same way they do when graph building. 
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name,
          shape=slot_variable_position.value_shape())
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  @contextlib.contextmanager
  def _distribution_strategy_scope(self):
    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
    if self._distribution_strategy and not distribute_ctx.has_strategy():
      with self._distribution_strategy.scope():
        yield self._distribution_strategy.scope()
    else:
      yield


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")


revived_types.register_revived_type(
    "tf_deprecated_optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])
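

# The `OptimizerV2` docstring above ("Creating a custom optimizer") lists the
# methods a subclass should override. The class below is a minimal illustrative
# sketch of that pattern only; it is not part of the public Keras API and is
# not exported. The class name `_ExampleGradientDescent` and the "accumulator"
# slot are hypothetical, chosen purely for demonstration; only methods defined
# in this module (`_set_hyper`, `add_slot`, `_decayed_lr`,
# `_resource_scatter_add`, `_serialize_hyperparameter`) are used.
class _ExampleGradientDescent(OptimizerV2):
  """Illustrative sketch: plain gradient descent built on `OptimizerV2`."""

  # `apply_gradients` is not overridden, so the base class may pass
  # `experimental_aggregate_gradients`.
  _HAS_AGGREGATE_GRAD = True

  def __init__(self, learning_rate=0.01, name="ExampleGradientDescent",
               **kwargs):
    super(_ExampleGradientDescent, self).__init__(name, **kwargs)
    # Register the learning rate as a hyperparameter; `lr` is accepted for
    # backwards compatibility, mirroring the built-in optimizers.
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

  def _create_slots(self, var_list):
    # One slot variable per trainable variable (unused by the update rule
    # below; it only shows how slots are allocated).
    for var in var_list:
      self.add_slot(var, "accumulator")

  def _resource_apply_dense(self, grad, handle, apply_state=None):
    # Dense update: var <- var - lr * grad.
    lr_t = self._decayed_lr(handle.dtype.base_dtype)
    return handle.assign_sub(lr_t * grad, use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
    # Sparse update for the de-duplicated indices: scatter-add -lr * grad.
    lr_t = self._decayed_lr(handle.dtype.base_dtype)
    return self._resource_scatter_add(handle, indices, -lr_t * grad)

  def get_config(self):
    config = super(_ExampleGradientDescent, self).get_config()
    config.update(
        {"learning_rate": self._serialize_hyperparameter("learning_rate")})
    return config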