1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mid level API for TPU Embeddings.""" 16 17import functools 18from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union 19 20from absl import logging 21 22from tensorflow.core.framework import attr_value_pb2 23from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 24from tensorflow.python.distribute import device_util 25from tensorflow.python.distribute import distribute_utils 26from tensorflow.python.distribute import distribution_strategy_context 27from tensorflow.python.distribute import sharded_variable 28from tensorflow.python.distribute import tpu_strategy 29from tensorflow.python.eager import context 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import device as tf_device 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import sparse_tensor 36from tensorflow.python.framework.tensor_shape import TensorShape 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import variable_scope 40from tensorflow.python.ops import variables as tf_variables 41from tensorflow.python.ops.ragged import ragged_tensor 42from tensorflow.python.saved_model import registration 43from tensorflow.python.saved_model import save_context 44from tensorflow.python.tpu import tpu 45from tensorflow.python.tpu import tpu_embedding_v2_utils 46from tensorflow.python.tpu.ops import tpu_ops 47from tensorflow.python.trackable import autotrackable 48from tensorflow.python.trackable import base 49from tensorflow.python.types import internal as internal_types 50from tensorflow.python.util import compat 51from tensorflow.python.util import nest 52from tensorflow.python.util import tf_inspect 53from tensorflow.python.util.tf_export import tf_export 54 55 56_HOOK_KEY = "TPUEmbedding_saveable" 57_NAME_KEY = "_tpu_embedding_layer" 58 59 60class TPUEmbeddingVariable(sharded_variable.ShardedVariableMixin): 61 """A ShardedVariable class for TPU.""" 62 63 @property 64 def _in_graph_mode(self): 65 return self.variables[0]._in_graph_mode # pylint: disable=protected-access 66 67 68def _add_key_attr(op, name): 69 op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name))) # pylint: disable=protected-access 70 71 72@tf_export("tpu.experimental.embedding.TPUEmbedding") 73class TPUEmbedding(autotrackable.AutoTrackable): 74 """The TPUEmbedding mid level API. 75 76 NOTE: When instantiated under a TPUStrategy, this class can only be created 77 once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to 78 re-initialize the embedding engine you must re-initialize the tpu as well. 79 Doing this will clear any variables from TPU, so ensure you have checkpointed 80 before you do this. If a further instances of the class are needed, 81 set the `initialize_tpu_embedding` argument to `False`. 82 83 This class can be used to support training large embeddings on TPU. When 84 creating an instance of this class, you must specify the complete set of 85 tables and features you expect to lookup in those tables. See the 86 documentation of `tf.tpu.experimental.embedding.TableConfig` and 87 `tf.tpu.experimental.embedding.FeatureConfig` for more details on the complete 88 set of options. We will cover the basic usage here. 89 90 NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object, 91 allowing different features to share the same table: 92 93 ```python 94 table_config_one = tf.tpu.experimental.embedding.TableConfig( 95 vocabulary_size=..., 96 dim=...) 97 table_config_two = tf.tpu.experimental.embedding.TableConfig( 98 vocabulary_size=..., 99 dim=...) 100 feature_config = { 101 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( 102 table=table_config_one), 103 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( 104 table=table_config_one), 105 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( 106 table=table_config_two)} 107 ``` 108 109 There are two modes under which the `TPUEmbedding` class can used. This 110 depends on if the class was created under a `TPUStrategy` scope or not. 111 112 Under `TPUStrategy`, we allow access to the method `enqueue`, `dequeue` and 113 `apply_gradients`. We will show examples below of how to use these to train 114 and evaluate your model. Under CPU, we only access to the `embedding_tables` 115 property which allow access to the embedding tables so that you can use them 116 to run model evaluation/prediction on CPU. 117 118 First lets look at the `TPUStrategy` mode. Initial setup looks like: 119 120 ```python 121 strategy = tf.distribute.TPUStrategy(...) 122 with strategy.scope(): 123 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 124 feature_config=feature_config, 125 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 126 ``` 127 128 When creating a distributed dataset that is to be passed to the enqueue 129 operation a special input option must be specified: 130 131 ```python 132 distributed_dataset = ( 133 strategy.distribute_datasets_from_function( 134 dataset_fn=..., 135 options=tf.distribute.InputOptions( 136 experimental_fetch_to_device=False)) 137 dataset_iterator = iter(distributed_dataset) 138 ``` 139 140 Different feature inputs can have different shapes. For dense and sparse 141 tensor, rank 2 and above is supported. For ragged tensor, although only rank 2 142 is supported, you can specify the output shape to be rank 2 and above. The 143 output shape specified in the FeatureConfig has the first priority. The input 144 shape passed in build method has second priority and the input shapes 145 auto detected from input feature has the lowest priority. The latter two will 146 be converted to output shapes by omitting the last dimension. If the lower 147 priority one has output shapes which don't match the former one. A ValueError 148 will be raised. Only when the former one has undefined output shapes, the 149 latter one can override. 150 151 NOTE: All batches passed to the layer can have different input shapes. But 152 these input shapes need to match with the output shapes set by either 153 `FeatureConfig` or build method except for ragged tensor. Only 2D 154 ragged tensor with output shape set to higher dimensions is allowed as 155 long as the total number of elements matches. All subsequent calls must have 156 the same input shapes. In the event that the input shapes cannot be 157 automatically determined by the enqueue method, you must call 158 the build method with the input shapes or provide output shapes in the 159 `FeatureConfig` to initialize the layer. 160 161 To use this API on TPU you should use a custom training loop. Below is an 162 example of a training and evaluation step: 163 164 ```python 165 @tf.function 166 def training_step(dataset_iterator, num_steps): 167 def tpu_step(tpu_features): 168 with tf.GradientTape() as tape: 169 activations = embedding.dequeue() 170 tape.watch(activations) 171 model_output = model(activations) 172 loss = ... # some function of labels and model_output 173 174 embedding_gradients = tape.gradient(loss, activations) 175 embedding.apply_gradients(embedding_gradients) 176 # Insert your model gradient and optimizer application here 177 178 for _ in tf.range(num_steps): 179 embedding_features, tpu_features = next(dataset_iterator) 180 embedding.enqueue(embedding_features, training=True) 181 strategy.run(tpu_step, args=(tpu_features, )) 182 183 @tf.function 184 def evalution_step(dataset_iterator, num_steps): 185 def tpu_step(tpu_features): 186 activations = embedding.dequeue() 187 model_output = model(activations) 188 # Insert your evaluation code here. 189 190 for _ in tf.range(num_steps): 191 embedding_features, tpu_features = next(dataset_iterator) 192 embedding.enqueue(embedding_features, training=False) 193 strategy.run(tpu_step, args=(tpu_features, )) 194 ``` 195 196 NOTE: The calls to `enqueue` have `training` set to `True` when 197 `embedding.apply_gradients` is used and set to `False` when 198 `embedding.apply_gradients` is not present in the function. If you don't 199 follow this pattern you may cause an error to be raised or the tpu may 200 deadlock. 201 202 In the above examples, we assume that the user has a dataset which returns 203 a tuple where the first element of the tuple matches the structure of what 204 was passed as the `feature_config` argument to the object initializer. Also we 205 utilize `tf.range` to get a `tf.while_loop` in order to increase performance. 206 207 When checkpointing your model, you should include your 208 `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is a 209 trackable object and saving it will save the embedding tables and their 210 optimizer slot variables: 211 212 ```python 213 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 214 checkpoint.save(...) 215 ``` 216 217 On CPU, only the `embedding_table` property is usable. This will allow you to 218 restore a checkpoint to the object and have access to the table variables: 219 220 ```python 221 model = model_fn(...) 222 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 223 feature_config=feature_config, 224 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 225 checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) 226 checkpoint.restore(...) 227 228 tables = embedding.embedding_tables 229 ``` 230 231 You can now use table in functions like `tf.nn.embedding_lookup` to perform 232 your embedding lookup and pass to your model. 233 234 """ 235 236 def __init__( 237 self, 238 feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic 239 optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access 240 pipeline_execution_with_tensor_core: bool = False): 241 """Creates the TPUEmbedding mid level API object. 242 243 ```python 244 strategy = tf.distribute.TPUStrategy(...) 245 with strategy.scope(): 246 embedding = tf.tpu.experimental.embedding.TPUEmbedding( 247 feature_config=tf.tpu.experimental.embedding.FeatureConfig( 248 table=tf.tpu.experimental.embedding.TableConfig( 249 dim=..., 250 vocabulary_size=...))) 251 ``` 252 253 Args: 254 feature_config: A nested structure of 255 `tf.tpu.experimental.embedding.FeatureConfig` configs. 256 optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`, 257 `tf.tpu.experimental.embedding.Adagrad` or 258 `tf.tpu.experimental.embedding.Adam`. When not created under 259 TPUStrategy may be set to None to avoid the creation of the optimizer 260 slot variables, useful for optimizing memory consumption when exporting 261 the model for serving where slot variables aren't needed. 262 pipeline_execution_with_tensor_core: If True, the TPU embedding 263 computations will overlap with the TensorCore computations (and hence 264 will be one step old). Set to True for improved performance. 265 266 Raises: 267 ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD, 268 Adam or Adagrad) or None when created under a TPUStrategy. 269 """ 270 self._strategy = distribution_strategy_context.get_strategy() 271 self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy, 272 tpu_strategy.TPUStrategyV2)) 273 self._pipeline_execution_with_tensor_core = ( 274 pipeline_execution_with_tensor_core) 275 276 self._feature_config = feature_config 277 self._output_shapes = [] 278 for feature in nest.flatten(feature_config): 279 self._output_shapes.append(feature.output_shape) 280 281 # The TPU embedding ops are slightly inconsistent with how they refer to 282 # tables: 283 # * The enqueue op takes a parallel list of tensors for input, one of those 284 # is the table id for the feature which matches the integer index of the 285 # table in the proto created by _create_config_proto(). 286 # * The recv_tpu_embedding_activations op emits lookups per table in the 287 # order from the config proto. 288 # * The send_tpu_embedding_gradients expects input tensors to be per table 289 # in the same order as the config proto. 290 # * Per optimizer load and retrieve ops are specified per table and take the 291 # table name rather than the table id. 292 # Thus we must fix a common order to tables and ensure they have unique 293 # names. 294 295 # Set table order here to the order of the first occurence of the table in a 296 # feature provided by the user. The order of this struct must be fixed 297 # to provide the user with deterministic behavior over multiple 298 # instantiations. 299 self._table_config = [] 300 for feature in nest.flatten(feature_config): 301 if feature.table not in self._table_config: 302 self._table_config.append(feature.table) 303 304 # Ensure tables have unique names. Also error check the optimizer as we 305 # specifically don't do that in the TableConfig class to allow high level 306 # APIs that are built on this to use strings/other classes to represent 307 # optimizers (before they are passed to this class). 308 table_names = [] 309 for i, table in enumerate(self._table_config): 310 if table.optimizer is None: 311 # TODO(bfontain) Should we allow some sort of optimizer merging here? 312 table.optimizer = optimizer 313 if ((table.optimizer is not None or self._using_tpu) and 314 not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)): # pylint: disable=protected-access 315 raise ValueError("{} is an unsupported optimizer class. Please pass an " 316 "instance of one of the optimizer classes under " 317 "tf.tpu.experimental.embedding.".format( 318 type(table.optimizer))) 319 if table.name is None: 320 table.name = "table_{}".format(i) 321 if table.name in table_names: 322 raise ValueError("Tables must have a unique name. " 323 f"Multiple tables with name {table.name} found.") 324 table_names.append(table.name) 325 326 if self._using_tpu: 327 # Extract a list of callable learning rates also in fixed order. Each 328 # table in the confix proto will get a index into this list and we will 329 # pass this list in the same order after evaluation to the 330 # send_tpu_embedding_gradients op. 331 self._dynamic_learning_rates = list({ 332 table.optimizer.learning_rate for table in self._table_config if 333 callable(table.optimizer.learning_rate)}) 334 335 # We need to list of host devices for the load/retrieve operations. 336 self._hosts = get_list_of_hosts(self._strategy) 337 338 self._built = False 339 self._verify_output_shapes_on_enqueue = True 340 341 def build(self, per_replica_input_shapes=None, per_replica_batch_size=None): # pylint:disable=g-bare-generic 342 """Create the underlying variables and initializes the TPU for embeddings. 343 344 This method creates the underlying variables (including slot variables). If 345 created under a TPUStrategy, this will also initialize the TPU for 346 embeddings. 347 348 This function will automatically get called by enqueue, which will try to 349 determine your output shapes. If this fails, you must manually 350 call this method before you call enqueue. 351 352 Args: 353 per_replica_input_shapes: A nested structure of The per replica input 354 shapes that matches the structure of the feature config. The input 355 shapes should be the same as the input shape of the feature (except for 356 ragged tensor) Note that it is fixed and the same per replica input 357 shapes must be used for both training and evaluation. If you want to 358 calculate this from the global input shapes, you can use 359 `num_replicas_in_sync` property of your strategy object. May be set to 360 None if not created under a TPUStrategy. 361 per_replica_batch_size: (Deprecated) The per replica batch size that you 362 intend to use. Note that is fixed and the same batch size must be used 363 for both training and evaluation. If you want to calculate this from the 364 global batch size, you can use `num_replicas_in_sync` property of your 365 strategy object. May be set to None if not created under a TPUStrategy. 366 367 Raises: 368 ValueError: If per_replica_input_shapes is inconsistent with the output 369 shapes stored in the feature config or the output shapes get from the 370 input shapes are not fully defined. 371 RuntimeError: If tpu embedding is already initialized on TPU. 372 """ 373 if self._built: 374 return 375 376 if self._using_tpu: 377 # If the tpu embedding is already initialized on TPU, raise runtime error. 378 # Below logic is not added in `initialize_system_for_tpu_embedding` 379 # because doing exception control flow in graph mode is difficult. 380 if tpu_ops.is_tpu_embedding_initialized(): 381 raise RuntimeError( 382 "TPU is already initialized for embeddings. This may be caused by " 383 "using multiple TPUEmbedding instances in a TPU scope which is " 384 "unsupported") 385 self._get_and_update_output_shapes_from_input(per_replica_input_shapes, 386 per_replica_batch_size) 387 388 self._config_proto = self._create_config_proto() 389 390 logging.info("Initializing TPU Embedding engine.") 391 tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto) 392 393 @def_function.function 394 def load_config(): 395 tpu.initialize_system_for_tpu_embedding(self._config_proto) 396 397 load_config() 398 logging.info("Done initializing TPU Embedding engine.") 399 400 # Create and load variables and slot variables into the TPU. 401 # Note that this is a dict of dicts. Keys to the first dict are table names. 402 # We would prefer to use TableConfigs, but then these variables won't be 403 # properly tracked by the tracking API. 404 self._variables = self._create_variables_and_slots() 405 406 self._built = True 407 408 # This is internally conditioned self._built and self._using_tpu 409 self._load_variables() 410 411 def _maybe_build(self, 412 output_shapes: Optional[Union[List[int], Iterable]] = None): # pylint:disable=g-bare-generic 413 if not self._built: 414 # This can be called while tracing a function, so we wrap the 415 # initialization code with init_scope so it runs eagerly, this means that 416 # it will not be included the function graph generated by tracing so that 417 # we can be sure that we only initialize the TPU for embeddings exactly 418 # once. 419 with ops.init_scope(): 420 self.build(output_shapes) 421 422 def _get_and_update_output_shapes_from_input( 423 self, 424 per_replica_input_shapes: Optional[List[TensorShape]] = None, 425 per_replica_batch_size: Optional[int] = None): 426 """Get and update the per replica output shapes from the input.""" 427 per_replica_output_shapes = None 428 if per_replica_batch_size and per_replica_input_shapes is None: 429 logging.warning( 430 "per_replica_batch_size argument will be deprecated, please specify " 431 "all the input shapes using per_replica_input_shapes argument.") 432 per_replica_output_shapes = self._get_output_shapes_from_batch_size( 433 per_replica_batch_size) 434 435 # Update the input shapes if provided. 436 if per_replica_input_shapes is not None: 437 if isinstance(per_replica_input_shapes, int): 438 logging.warning( 439 "Passing batch size to per_replica_input_shapes argument will be" 440 " deprecated, please specify all the input shapes using" 441 " per_replica_input_shapes argument.") 442 per_replica_output_shapes = self._get_output_shapes_from_batch_size( 443 per_replica_input_shapes) 444 else: 445 nest.assert_same_structure( 446 nest.flatten(per_replica_input_shapes), 447 nest.flatten(self._feature_config)) 448 449 # Convert the nested structure to list. 450 per_replica_input_shapes = nest.flatten(per_replica_input_shapes) 451 452 per_replica_output_shapes = self._get_output_shapes_from_input_shapes( 453 per_replica_input_shapes) 454 455 if per_replica_output_shapes is not None: 456 457 # Check the output shapes with existing output shapes setting. 458 self._check_output_shapes(per_replica_output_shapes) 459 460 # Update the output shapes with existing output shapes setting. 461 # This is necessary Because the output shapes might be missing from 462 # the feature config, the usr can set it: 463 # 1. calling the build method 464 # 2. output shapes auto detected when calling the dequeue method for 465 # for the first time. The dequeue method will call build method 466 # with the output shapes. 467 # Either these two situations will lead to an update to the existing 468 # output shapes. 469 self._update_output_shapes(per_replica_output_shapes) 470 471 # Check if the output shapes are fully defined. This is required in order 472 # to set them in the feature descriptor field of the tpu embedding config 473 # proto. 474 self._check_output_shapes_fully_defined() 475 476 def _get_output_shapes_from_input_shapes( 477 self, input_shapes: List[TensorShape]) -> List[TensorShape]: 478 """Get output shapes from the flattened input shapes list.""" 479 output_shapes = [] 480 for input_shape, feature in zip(input_shapes, 481 nest.flatten(self._feature_config)): 482 if input_shape.rank is None or input_shape.rank < 1: 483 raise ValueError( 484 "Received input tensor of shape {}. Rank must be 1 and above" 485 .format(input_shape)) 486 # Update the input shape with the max sequence length. Only update when 487 # 1. Input feature is 2D ragged or sparse tensor. 488 # 2. Output shape is not set in the feature config and the max sequence 489 # length is set. 490 if (len(input_shape) == 2 and input_shape[-1] != 1 and 491 not feature.output_shape and feature.max_sequence_length > 0): 492 input_shape_list = input_shape.as_list() 493 input_shape_list.insert( 494 len(input_shape_list) - 1, feature.max_sequence_length) 495 input_shape = TensorShape(input_shape_list) 496 if input_shape.rank == 1: 497 output_shapes.append(input_shape) 498 else: 499 output_shapes.append(input_shape[:-1]) 500 return output_shapes 501 502 @property 503 def embedding_tables( 504 self 505 ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]: 506 """Returns a dict of embedding tables, keyed by `TableConfig`. 507 508 This property only works when the `TPUEmbedding` object is created under a 509 non-TPU strategy. This is intended to be used to for CPU based lookup when 510 creating a serving checkpoint. 511 512 Returns: 513 A dict of embedding tables, keyed by `TableConfig`. 514 515 Raises: 516 RuntimeError: If object was created under a `TPUStrategy`. 517 """ 518 # We don't support returning tables on TPU due to their sharded nature and 519 # the fact that when using a TPUStrategy: 520 # 1. Variables are stale and are only updated when a checkpoint is made. 521 # 2. Updating the variables won't affect the actual tables on the TPU. 522 if self._using_tpu: 523 if save_context.in_save_context(): 524 return {table: self._variables[table.name]["parameters"].variables[0] 525 for table in self._table_config} 526 raise RuntimeError("Unable to retrieve embedding tables when using a TPU " 527 "strategy. If you need access, save your model, " 528 "create this object under a CPU strategy and restore.") 529 530 self._maybe_build(None) 531 532 # Only return the tables and not the slot variables. On CPU this are honest 533 # tf.Variables. 534 return {table: self._variables[table.name]["parameters"] 535 for table in self._table_config} 536 537 def _create_config_proto( 538 self 539 ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration: 540 """Creates the TPUEmbeddingConfiguration proto. 541 542 This proto is used to initialize the TPU embedding engine. 543 544 Returns: 545 A TPUEmbeddingConfiguration proto. 546 """ 547 548 config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() 549 550 # Map each callable dynamic learning rate to its in index in the list. 551 # The learning rate index is the index of the dynamic learning rate for this 552 # table (if it exists) in the list we created at initialization. We don't 553 # simply create one learning rate index per table as this has extremely bad 554 # performance characteristics. The more separate optimization configurations 555 # we have, the worse the performance will be. 556 learning_rate_index = {r: i for i, r in enumerate( 557 self._dynamic_learning_rates)} 558 559 for table in self._table_config: 560 table_descriptor = config_proto.table_descriptor.add() 561 table_descriptor.name = table.name 562 563 # For small tables, we pad to the number of hosts so that at least one 564 # id will be assigned to each host. 565 table_descriptor.vocabulary_size = max(table.vocabulary_size, 566 self._strategy.extended.num_hosts) 567 table_descriptor.dimension = table.dim 568 569 parameters = table_descriptor.optimization_parameters 570 571 # We handle the learning rate separately here and don't allow the 572 # optimization class to handle this, as it doesn't know about dynamic 573 # rates. 574 if callable(table.optimizer.learning_rate): 575 parameters.learning_rate.dynamic.tag = ( 576 learning_rate_index[table.optimizer.learning_rate]) 577 else: 578 parameters.learning_rate.constant = table.optimizer.learning_rate 579 580 # Use optimizer to handle the rest of the parameters. 581 table.optimizer._set_optimization_parameters(parameters) # pylint: disable=protected-access 582 583 table_to_id = {table: i for i, table in enumerate(self._table_config)} 584 585 # Set feature descriptor field in the config proto. 586 for feature, output_shape in zip( 587 nest.flatten(self._feature_config), self._output_shapes): 588 feature_descriptor = config_proto.feature_descriptor.add() 589 590 if feature.name: 591 feature_descriptor.name = feature.name 592 593 feature_descriptor.table_id = table_to_id[feature.table] 594 # The input shape of the feature is the actual shape of the input tensor 595 # except the last dimension because the last dimension will always be 596 # reduced. 597 feature_descriptor.input_shape.extend(output_shape.as_list()) 598 599 # Always set mode to training, we override the mode during enqueue. 600 config_proto.mode = ( 601 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING) 602 603 config_proto.num_hosts = self._strategy.extended.num_hosts 604 config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync 605 606 # TODO(bfontain): Allow users to pick MOD for the host sharding. 607 config_proto.sharding_strategy = ( 608 tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT) 609 config_proto.pipeline_execution_with_tensor_core = ( 610 self._pipeline_execution_with_tensor_core) 611 612 return config_proto 613 614 def apply_gradients(self, gradients, name: Optional[Text] = None): 615 """Applies the gradient update to the embedding tables. 616 617 If a gradient of `None` is passed in any position of the nested structure, 618 then an gradient update with a zero gradient is applied for that feature. 619 For optimizers like SGD or Adagrad, this is the same as applying no update 620 at all. For lazy Adam and other sparsely applied optimizers with decay, 621 ensure you understand the effect of applying a zero gradient. 622 623 ```python 624 strategy = tf.distribute.TPUStrategy(...) 625 with strategy.scope(): 626 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 627 628 distributed_dataset = ( 629 strategy.distribute_datasets_from_function( 630 dataset_fn=..., 631 options=tf.distribute.InputOptions( 632 experimental_fetch_to_device=False)) 633 dataset_iterator = iter(distributed_dataset) 634 635 @tf.function 636 def training_step(): 637 def tpu_step(tpu_features): 638 with tf.GradientTape() as tape: 639 activations = embedding.dequeue() 640 tape.watch(activations) 641 642 loss = ... # some computation involving activations 643 644 embedding_gradients = tape.gradient(loss, activations) 645 embedding.apply_gradients(embedding_gradients) 646 647 embedding_features, tpu_features = next(dataset_iterator) 648 embedding.enqueue(embedding_features, training=True) 649 strategy.run(tpu_step, args=(tpu_features, )) 650 651 training_step() 652 ``` 653 654 Args: 655 gradients: A nested structure of gradients, with structure matching the 656 `feature_config` passed to this object. 657 name: A name for the underlying op. 658 659 Raises: 660 RuntimeError: If called when object wasn't created under a `TPUStrategy` 661 or if not built (either by manually calling build or calling enqueue). 662 ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a 663 `tf.Tensor` of the incorrect shape is passed in. Also if 664 the size of any sequence in `gradients` does not match corresponding 665 sequence in `feature_config`. 666 TypeError: If the type of any sequence in `gradients` does not match 667 corresponding sequence in `feature_config`. 668 """ 669 if not self._using_tpu: 670 raise RuntimeError("apply_gradients is not valid when TPUEmbedding " 671 "object is not created under a TPUStrategy.") 672 673 if not self._built: 674 raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding " 675 "object. Please either call enqueue first or manually " 676 "call the build method.") 677 678 nest.assert_same_structure(self._feature_config, gradients) 679 updated_gradients = [] 680 for (path, gradient), feature, output_shape in zip( 681 nest.flatten_with_joined_string_paths(gradients), 682 nest.flatten(self._feature_config), self._output_shapes): 683 full_output_shape = list(output_shape) + [feature.table.dim] 684 if gradient is not None and not isinstance(gradient, ops.Tensor): 685 raise ValueError( 686 f"found non-tensor type: {type(gradient)} at path {path}.") 687 if gradient is not None: 688 if gradient.shape != full_output_shape: 689 raise ValueError("Found gradient of shape {} at path {}. Expected " 690 "shape {}.".format(gradient.shape, path, 691 full_output_shape)) 692 else: 693 # No gradient for this feature, since we must give a gradient for all 694 # features, pass in a zero tensor here. Note that this is not correct 695 # for all optimizers. 696 logging.warning( 697 "No gradient passed for feature %s, sending zero " 698 "gradient. This may not be correct behavior for certain " 699 "optimizers like Adam.", path) 700 gradient = array_ops.zeros(full_output_shape, dtype=dtypes.float32) 701 # Some gradients can be passed with op which shape is not correctly set. 702 # This ensures that the shape of the gradient is correctly set. 703 updated_gradients.append( 704 array_ops.reshape(gradient, shape=gradient.shape)) 705 op = tpu_ops.send_tpu_embedding_gradients( 706 inputs=updated_gradients, 707 learning_rates=[ 708 math_ops.cast(fn(), dtype=dtypes.float32) 709 for fn in self._dynamic_learning_rates 710 ], 711 config=self._config_proto.SerializeToString()) 712 713 # Apply the name tag to the op. 714 if name is not None: 715 _add_key_attr(op, name) 716 717 def dequeue(self, name: Optional[Text] = None): 718 """Get the embedding results. 719 720 Returns a nested structure of `tf.Tensor` objects, matching the structure of 721 the `feature_config` argument to the `TPUEmbedding` class. The output shape 722 of the tensors is `(*output_shape, dim)`, `dim` is the dimension of the 723 corresponding `TableConfig`. For output_shape, there are three places where 724 it can be set. 725 1. FeatureConfig provided in the __init__ function. 726 2. Per_replica_output_shapes by directly calling the build method 727 after initializing the tpu embedding class. 728 3. Auto detected from the shapes of the input feature. 729 The priority of these places is the exact same order. 730 731 ```python 732 strategy = tf.distribute.TPUStrategy(...) 733 with strategy.scope(): 734 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 735 736 distributed_dataset = ( 737 strategy.distribute_datasets_from_function( 738 dataset_fn=..., 739 options=tf.distribute.InputOptions( 740 experimental_fetch_to_device=False)) 741 dataset_iterator = iter(distributed_dataset) 742 743 @tf.function 744 def training_step(): 745 def tpu_step(tpu_features): 746 with tf.GradientTape() as tape: 747 activations = embedding.dequeue() 748 tape.watch(activations) 749 750 loss = ... # some computation involving activations 751 752 embedding_gradients = tape.gradient(loss, activations) 753 embedding.apply_gradients(embedding_gradients) 754 755 embedding_features, tpu_features = next(dataset_iterator) 756 embedding.enqueue(embedding_features, training=True) 757 strategy.run(tpu_step, args=(tpu_features, )) 758 759 training_step() 760 ``` 761 762 Args: 763 name: A name for the underlying op. 764 765 Returns: 766 A nested structure of tensors, with the same structure as `feature_config` 767 passed to this instance of the `TPUEmbedding` object. 768 769 Raises: 770 RuntimeError: If called when object wasn't created under a `TPUStrategy` 771 or if not built (either by manually calling build or calling enqueue). 772 """ 773 if not self._using_tpu: 774 raise RuntimeError("dequeue is not valid when TPUEmbedding object is not " 775 "created under a TPUStrategy.") 776 777 if not self._built: 778 raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. " 779 "Please either call enqueue first or manually call " 780 "the build method.") 781 782 # The activations returned by this op are per feature. 783 activations = tpu_ops.recv_tpu_embedding_activations( 784 num_outputs=len(self._config_proto.feature_descriptor), 785 config=self._config_proto.SerializeToString()) 786 787 # Apply the name tag to the op. 788 if name is not None: 789 _add_key_attr(activations[0].op, name) 790 791 # Pack the list back into the same nested structure as the features. 792 return nest.pack_sequence_as(self._feature_config, activations) 793 794 def _create_variables_and_slots( 795 self 796 ) -> Dict[Text, Dict[Text, tf_variables.Variable]]: 797 """Create variables for TPU embeddings. 798 799 Note under TPUStrategy this will ensure that all creations happen within a 800 variable creation scope of the sharded variable creator. 801 802 Returns: 803 A dict of dicts. The outer dict is keyed by the table names and the inner 804 dicts are keyed by 'parameters' and the slot variable names. 805 """ 806 807 def create_variables(table): 808 """Create all variables.""" 809 variable_shape = (table.vocabulary_size, table.dim) 810 811 def getter(name, shape, dtype, initializer, trainable): 812 del shape 813 # _add_variable_with_custom_getter clears the shape sometimes, so we 814 # take the global shape from outside the getter. 815 initial_value = functools.partial(initializer, variable_shape, 816 dtype=dtype) 817 return tf_variables.Variable( 818 name=name, 819 initial_value=initial_value, 820 shape=variable_shape, 821 dtype=dtype, 822 trainable=trainable) 823 824 def variable_creator(name, initializer, trainable=True): 825 # use add_variable_with_custom_getter here so that we take advantage of 826 # the checkpoint loading to allow restore before the variables get 827 # created which avoids double initialization. 828 return self._add_variable_with_custom_getter( 829 name=name, 830 initializer=initializer, 831 shape=variable_shape, 832 dtype=dtypes.float32, 833 getter=getter, 834 trainable=trainable) 835 836 parameters = variable_creator(table.name, table.initializer, 837 trainable=not self._using_tpu) 838 839 def slot_creator(name, initializer): 840 return variable_creator(table.name + "/" + name, 841 initializer, 842 False) 843 844 if table.optimizer is not None: 845 slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access 846 else: 847 slot_vars = {} 848 slot_vars["parameters"] = parameters 849 return slot_vars 850 851 # Store tables based on name rather than TableConfig as we can't track 852 # through dicts with non-string keys, i.e. we won't be able to save. 853 variables = {} 854 for table in self._table_config: 855 if not self._using_tpu: 856 variables[table.name] = create_variables(table) 857 else: 858 with variable_scope.variable_creator_scope( 859 make_sharded_variable_creator(self._hosts)): 860 variables[table.name] = create_variables(table) 861 862 return variables 863 864 def _load_variables(self): 865 # Only load the variables if we are: 866 # 1) Using TPU 867 # 2) Variables are created 868 # 3) Not in save context (except if running eagerly) 869 if self._using_tpu and self._built and not ( 870 not context.executing_eagerly() and save_context.in_save_context()): 871 _load_variables_impl(self._config_proto.SerializeToString(), 872 self._hosts, 873 self._variables, 874 self._table_config) 875 876 def _retrieve_variables(self): 877 # Only retrieve the variables if we are: 878 # 1) Using TPU 879 # 2) Variables are created 880 # 3) Not in save context (except if running eagerly) 881 if self._using_tpu and self._built and not ( 882 not context.executing_eagerly() and save_context.in_save_context()): 883 _retrieve_variables_impl(self._config_proto.SerializeToString(), 884 self._hosts, 885 self._variables, 886 self._table_config) 887 888 # Some helper functions for the below enqueue function. 889 def _add_data_for_tensor(self, tensor, weight, indices, values, weights, 890 int_zeros, float_zeros, path): 891 if weight is not None: 892 raise ValueError( 893 "Weight specified for dense input {}, which is not allowed. " 894 "Weight will always be 1 in this case.".format(path)) 895 # For tensors, there are no indices and no weights. 896 indices.append(int_zeros) 897 values.append(math_ops.cast(array_ops.reshape(tensor, [-1]), dtypes.int64)) 898 weights.append(float_zeros) 899 900 def _add_data_for_sparse_tensor(self, tensor, weight, indices, values, 901 weights, int_zeros, float_zeros, path, 902 feature): 903 sample_indices = math_ops.cast(tensor.indices, dtypes.int32) 904 if tensor.shape.rank == 2: 905 if not feature.output_shape and feature.max_sequence_length > 0: 906 # Add one dimension to the last axis. 907 sample_indices = array_ops.pad( 908 sample_indices, paddings=[[0, 0], [0, 1]]) 909 indices.append(sample_indices) 910 values.append(math_ops.cast(tensor.values, dtypes.int64)) 911 # If we have weights they must be a SparseTensor. 912 if weight is not None: 913 if not isinstance(weight, sparse_tensor.SparseTensor): 914 raise ValueError("Weight for {} is type {} which does not match " 915 "type input which is SparseTensor.".format( 916 path, type(weight))) 917 weights.append(math_ops.cast(weight.values, dtypes.float32)) 918 else: 919 weights.append(float_zeros) 920 921 def _add_data_for_ragged_tensor(self, tensor, weight, row_splits, values, 922 weights, int_zeros, float_zeros, path, 923 feature): 924 row_splits.append(math_ops.cast(tensor.row_splits, dtypes.int32)) 925 values.append(math_ops.cast(tensor.values, dtypes.int64)) 926 # If we have weights they must be a RaggedTensor. 927 if weight is not None: 928 if not isinstance(weight, ragged_tensor.RaggedTensor): 929 raise ValueError("Weight for {} is type {} which does not match " 930 "type input which is RaggedTensor.".format( 931 path, type(weight))) 932 weights.append(math_ops.cast(weight.values, dtypes.float32)) 933 else: 934 weights.append(float_zeros) 935 936 def _generate_enqueue_op( 937 self, 938 flat_inputs: List[internal_types.NativeObject], 939 flat_weights: List[Optional[internal_types.NativeObject]], 940 flat_features: List[tpu_embedding_v2_utils.FeatureConfig], 941 device_ordinal: int, 942 mode_override: Text 943 ) -> ops.Operation: 944 """Outputs a the enqueue op given the inputs and weights. 945 946 Args: 947 flat_inputs: A list of input tensors. 948 flat_weights: A list of input weights (or None) of the same length as 949 flat_inputs. 950 flat_features: A list of FeatureConfigs of the same length as flat_inputs. 951 device_ordinal: The device to create the enqueue op for. 952 mode_override: A tensor containing the string "train" or "inference". 953 954 Returns: 955 The enqueue op. 956 """ 957 # Combiners are per table, list in the same order as the table order. 958 combiners = [table.combiner for table in self._table_config] 959 960 # These parallel arrays will be the inputs to the enqueue op. 961 # sample_indices for sparse, row_splits for ragged. 962 indices_or_row_splits = [] 963 values = [] 964 weights = [] 965 966 # We have to supply a empty/zero tensor in a list position where we don't 967 # have data (e.g. indices for standard Tensor input, weight when no weight 968 # is specified). We create one op here per call, so that we reduce the 969 # graph size. 970 int_zeros = array_ops.zeros((0,), dtype=dtypes.int32) 971 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 972 973 # In the following loop we insert casts so that everything is either int32 974 # or float32. This is because op inputs which are lists of tensors must be 975 # of the same type within the list. Moreover the CPU implementations of 976 # these ops cast to these types anyway, so we don't lose any data by casting 977 # early. 978 for inp, weight, (path, feature) in zip( 979 flat_inputs, flat_weights, flat_features): 980 if isinstance(inp, ops.Tensor): 981 self._add_data_for_tensor(inp, weight, indices_or_row_splits, values, 982 weights, int_zeros, float_zeros, path) 983 elif isinstance(inp, sparse_tensor.SparseTensor): 984 self._add_data_for_sparse_tensor(inp, weight, indices_or_row_splits, 985 values, weights, int_zeros, 986 float_zeros, path, feature) 987 elif isinstance(inp, ragged_tensor.RaggedTensor): 988 self._add_data_for_ragged_tensor(inp, weight, indices_or_row_splits, 989 values, weights, int_zeros, 990 float_zeros, path, feature) 991 else: 992 raise ValueError("Input {} is of unknown type {}. Please only pass " 993 "Tensor, SparseTensor or RaggedTensor as input to " 994 "enqueue.".format(path, type(inp))) 995 996 return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch( 997 sample_indices_or_row_splits=indices_or_row_splits, 998 embedding_indices=values, 999 aggregation_weights=weights, 1000 mode_override=mode_override, 1001 device_ordinal=device_ordinal, 1002 combiners=combiners) 1003 1004 def _raise_error_for_incorrect_control_flow_context(self): 1005 """Raises an error if we are not in the TPUReplicateContext.""" 1006 # Do not allow any XLA control flow (i.e. control flow in between a 1007 # TPUStrategy's run call and the call to this function), as we can't 1008 # extract the enqueue from the head when in XLA control flow. 1009 graph = ops.get_default_graph() 1010 in_tpu_ctx = False 1011 while graph is not None: 1012 ctx = graph._get_control_flow_context() # pylint: disable=protected-access 1013 while ctx is not None: 1014 if isinstance(ctx, tpu.TPUReplicateContext): 1015 in_tpu_ctx = True 1016 break 1017 ctx = ctx.outer_context 1018 if in_tpu_ctx: 1019 break 1020 graph = getattr(graph, "outer_graph", None) 1021 if graph != ops.get_default_graph() and in_tpu_ctx: 1022 raise RuntimeError( 1023 "Current graph {} does not match graph which contains " 1024 "TPUReplicateContext {}. This is most likely due to the fact that " 1025 "enqueueing embedding data is called inside control flow or a " 1026 "nested function inside `strategy.run`. This is not supported " 1027 "because outside compilation fails to extract the enqueue ops as " 1028 "head of computation.".format(ops.get_default_graph(), graph)) 1029 return in_tpu_ctx 1030 1031 def _raise_error_for_non_direct_inputs(self, features): 1032 """Checks all tensors in features to see if they are a direct input.""" 1033 1034 # expand_composites here is important: as composite tensors pass through 1035 # tpu.replicate, they get 'flattened' into their component tensors and then 1036 # repacked before being passed to the tpu function. In means that it is the 1037 # component tensors which are produced by an op with the 1038 # "_tpu_input_identity" attribute. 1039 for path, input_tensor in nest.flatten_with_joined_string_paths( 1040 features, expand_composites=True): 1041 if input_tensor.op.type == "Placeholder": 1042 continue 1043 try: 1044 is_input = input_tensor.op.get_attr("_tpu_input_identity") 1045 except ValueError: 1046 is_input = False 1047 if not is_input: 1048 raise ValueError( 1049 "Received input tensor {} which is the output of op {} (type {}) " 1050 "which does not have the `_tpu_input_identity` attr. Please " 1051 "ensure that the inputs to this layer are taken directly from " 1052 "the arguments of the function called by " 1053 "strategy.run. Two possible causes are: dynamic batch size " 1054 "support or you are using a keras layer and are not passing " 1055 "tensors which match the dtype of the `tf.keras.Input`s." 1056 "If you are triggering dynamic batch size support, you can " 1057 "disable it by passing tf.distribute.RunOptions(" 1058 "experimental_enable_dynamic_batch_size=False) to the options " 1059 "argument of strategy.run().".format(path, 1060 input_tensor.op.name, 1061 input_tensor.op.type)) 1062 1063 def _raise_error_for_inputs_not_on_cpu(self, flat_inputs, flat_paths): 1064 """Checks all tensors in features to see are placed on the CPU.""" 1065 1066 def check_device(path, device_string): 1067 spec = tf_device.DeviceSpec.from_string(device_string) 1068 if spec.device_type == "TPU": 1069 raise ValueError( 1070 "Received input tensor {} which is on a TPU input device {}. Input " 1071 "tensors for TPU embeddings must be placed on the CPU. Please " 1072 "ensure that your dataset is prefetching tensors to the host by " 1073 "setting the 'experimental_fetch_to_device' option of the " 1074 "dataset distribution function. See the documentation of the " 1075 "enqueue method for an example.".format(path, device_string)) 1076 1077 # expand_composites here is important, we need to check the device of each 1078 # underlying tensor. 1079 for input_tensor, input_path in zip(flat_inputs, flat_paths): 1080 if nest.is_nested_or_composite(input_tensor): 1081 input_tensors = nest.flatten(input_tensor, expand_composites=True) 1082 else: 1083 input_tensors = [input_tensor] 1084 for t in input_tensors: 1085 if (t.op.type == "Identity" and 1086 t.op.inputs[0].op.type == "TPUReplicatedInput"): 1087 for tensor in t.op.inputs[0].op.inputs: 1088 check_device(input_path, tensor.device) 1089 else: 1090 check_device(input_path, t.device) 1091 1092 def enqueue( 1093 self, 1094 features, 1095 weights=None, 1096 training: bool = True, 1097 name: Optional[Text] = None, 1098 device: Optional[Text] = None): 1099 """Enqueues id tensors for embedding lookup. 1100 1101 This function enqueues a structure of features to be looked up in the 1102 embedding tables. We expect that the input shapes of each of the tensors in 1103 features matches the output shapes set via FeatureConfig or build method 1104 (if any). the output shapes will be auto detected based on the input shapes 1105 with the max_sequence_length or output shape setting in the FeatureConfig. 1106 Note that the output shapes is based on per replica batch size. 1107 If your input dataset is batched to the global batch size and you use 1108 `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset` 1109 or if you use `distribute_datasets_from_function` and batch 1110 to the per core batch size computed by the context passed to your input 1111 function, the output shapes should match automatically. 1112 1113 The auto detected the output shapes: 1114 1. For dense tensor, if rank 2 or above, make sure the tensor has last 1115 dimension as 1. The output shape will be the input shape excluding 1116 the last dimension. 1117 2. For sparse tensor, make sure the tensor has rank 2 and above. 1118 a. If feature config has max_sequence_length equals 0 or output shape 1119 set (the max_sequence_length setting will be ignored), the 1120 output shape will be the input shape excluding the last dimension. 1121 b. Otherwize if the tensor is rank 2, the output shape will be input 1122 shape with last dimension set as max_sequence_length. If the 1123 tensor is above rank 2, the output shape will be the input shape 1124 excluding the last dimension and the last dimension of the output 1125 shape will be set to max_sequence_length. 1126 3. For ragged tensor, make sure the tensor has rank 2. 1127 a. If feature config has max_sequence_length equals 0 or output shape 1128 set (the max_sequence_length setting will be ignored), the 1129 output shape will be the input shape excluding the last dimension. 1130 b. Otherwise, the output shape will be the input shape excluding the 1131 last dimension and the last dimension of the output shape will be 1132 set to max_sequence_length. 1133 1134 ```python 1135 strategy = tf.distribute.TPUStrategy(...) 1136 with strategy.scope(): 1137 embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) 1138 1139 distributed_dataset = ( 1140 strategy.distribute_datasets_from_function( 1141 dataset_fn=..., 1142 options=tf.distribute.InputOptions( 1143 experimental_fetch_to_device=False)) 1144 dataset_iterator = iter(distributed_dataset) 1145 1146 @tf.function 1147 def training_step(): 1148 def tpu_step(tpu_features): 1149 with tf.GradientTape() as tape: 1150 activations = embedding.dequeue() 1151 tape.watch(activations) 1152 1153 loss = ... # some computation involving activations 1154 1155 embedding_gradients = tape.gradient(loss, activations) 1156 embedding.apply_gradients(embedding_gradients) 1157 1158 embedding_features, tpu_features = next(dataset_iterator) 1159 embedding.enqueue(embedding_features, training=True) 1160 strategy.run(tpu_step, args=(tpu_features,)) 1161 1162 training_step() 1163 ``` 1164 1165 NOTE: You should specify `training=True` when using 1166 `embedding.apply_gradients` as above and `training=False` when not using 1167 `embedding.apply_gradients` (e.g. for frozen embeddings or when doing 1168 evaluation). 1169 1170 For finer grained control, in the above example the line 1171 1172 ``` 1173 embedding.enqueue(embedding_features, training=True) 1174 ``` 1175 1176 may be replaced with 1177 1178 ``` 1179 per_core_embedding_features = self.strategy.experimental_local_results( 1180 embedding_features) 1181 1182 def per_core_enqueue(ctx): 1183 core_id = ctx.replica_id_in_sync_group 1184 device = strategy.extended.worker_devices[core_id] 1185 embedding.enqueue(per_core_embedding_features[core_id], 1186 device=device) 1187 1188 strategy.experimental_distribute_values_from_function( 1189 per_core_queue_inputs) 1190 ``` 1191 1192 Args: 1193 features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or 1194 `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs 1195 will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` 1196 or `tf.RaggedTensor` is supported per call. 1197 weights: If not `None`, a nested structure of `tf.Tensor`s, 1198 `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except 1199 that the tensors should be of float type (and they will be downcast to 1200 `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the 1201 same for the parallel entries from `features` and similarly for 1202 `tf.RaggedTensor`s we assume the row_splits are the same. 1203 training: Defaults to `True`. If `False`, enqueue the batch as inference 1204 batch (forward pass only). Do not call `apply_gradients` when this is 1205 `False` as this may lead to a deadlock. 1206 name: A name for the underlying op. 1207 device: The device name (e.g. '/task:0/device:TPU:2') where this batch 1208 should be enqueued. This should be set if and only if features is not a 1209 `tf.distribute.DistributedValues` and enqueue is not being called 1210 inside a TPU context (e.g. inside `TPUStrategy.run`). 1211 1212 Raises: 1213 ValueError: When called inside a strategy.run call and input is not 1214 directly taken from the args of the `strategy.run` call. Also if 1215 the size of any sequence in `features` does not match corresponding 1216 sequence in `feature_config`. Similarly for `weights`, if not `None`. 1217 If input shapes of features is unequal or different from a previous 1218 call. 1219 RuntimeError: When called inside a strategy.run call and inside XLA 1220 control flow. If batch_size is not able to be determined and build was 1221 not called. 1222 TypeError: If the type of any sequence in `features` does not match 1223 corresponding sequence in `feature_config`. Similarly for `weights`, if 1224 not `None`. 1225 """ 1226 if not self._using_tpu: 1227 raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " 1228 "created under a TPUStrategy.") 1229 1230 in_tpu_context = self._raise_error_for_incorrect_control_flow_context() 1231 1232 nest.assert_same_structure(self._feature_config, features) 1233 1234 if not self._verify_output_shapes_on_enqueue: 1235 if not self._output_shapes or not self._built: 1236 raise ValueError( 1237 "Configured not to check output shapes on each enqueue() call; please " 1238 "ensure build() was called with output shapes to initialize " 1239 "the TPU for embeddings.") 1240 else: 1241 input_shapes = self._get_input_shapes(features, in_tpu_context) 1242 1243 self._maybe_build(input_shapes) 1244 # If is already built, we still need to check if the output shapes matches 1245 # with the previous ones. 1246 self._check_output_shapes( 1247 self._get_output_shapes_from_input_shapes(input_shapes)) 1248 1249 flat_inputs = nest.flatten(features) 1250 flat_weights = [None] * len(flat_inputs) 1251 if weights is not None: 1252 nest.assert_same_structure(self._feature_config, weights) 1253 flat_weights = nest.flatten(weights) 1254 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 1255 flat_paths, _ = zip(*flat_features) 1256 1257 self._raise_error_for_inputs_not_on_cpu(flat_inputs, flat_paths) 1258 # If we are in a tpu_context, automatically apply outside compilation. 1259 if in_tpu_context: 1260 self._raise_error_for_non_direct_inputs(features) 1261 1262 def generate_enqueue_ops(): 1263 """Generate enqueue ops for outside compilation.""" 1264 # Note that we put array_ops.where_v2 rather than a python if so that 1265 # the op is explicitly create and the constant ops are both in the graph 1266 # even though we don't expect training to be a tensor (and thus generate 1267 # control flow automatically). This need to make it easier to re-write 1268 # the graph later if we need to fix which mode needs to be used. 1269 mode_override = array_ops.where_v2(training, 1270 constant_op.constant("train"), 1271 constant_op.constant("inference")) 1272 # Device ordinal is -1 here, a later rewrite will fix this once the op 1273 # is expanded by outside compilation. 1274 enqueue_op = self._generate_enqueue_op( 1275 flat_inputs, flat_weights, flat_features, device_ordinal=-1, 1276 mode_override=mode_override) 1277 1278 # Apply the name tag to the op. 1279 if name is not None: 1280 _add_key_attr(enqueue_op, name) 1281 1282 # Ensure that this op has outbound control flow, otherwise it won't be 1283 # executed. 1284 ops.get_default_graph().control_outputs.append(enqueue_op) 1285 1286 tpu.outside_compilation(generate_enqueue_ops) 1287 1288 elif device is None: 1289 mode_override = "train" if training else "inference" 1290 # We generate enqueue ops per device, so we need to gather the all 1291 # features for a single device in to a dict. 1292 # We rely here on the fact that the devices in the PerReplica value occur 1293 # in the same (standard) order as self._strategy.extended.worker_devices. 1294 enqueue_ops = [] 1295 for replica_id in range(self._strategy.num_replicas_in_sync): 1296 replica_inputs = distribute_utils.select_replica(replica_id, 1297 flat_inputs) 1298 replica_weights = distribute_utils.select_replica(replica_id, 1299 flat_weights) 1300 tpu_device = self._strategy.extended.worker_devices[replica_id] 1301 # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0 1302 # the device ordinal is the last number 1303 device_ordinal = ( 1304 tf_device.DeviceSpec.from_string(tpu_device).device_index) 1305 1306 with ops.device(device_util.get_host_for_device(tpu_device)): 1307 enqueue_op = self._generate_enqueue_op( 1308 replica_inputs, replica_weights, flat_features, 1309 device_ordinal=device_ordinal, mode_override=mode_override) 1310 1311 # Apply the name tag to the op. 1312 if name is not None: 1313 _add_key_attr(enqueue_op, name) 1314 enqueue_ops.append(enqueue_op) 1315 ops.get_default_graph().control_outputs.extend(enqueue_ops) 1316 else: 1317 mode_override = "train" if training else "inference" 1318 device_spec = tf_device.DeviceSpec.from_string(device) 1319 if device_spec.device_type != "TPU": 1320 raise ValueError( 1321 "Non-TPU device {} passed to enqueue.".format(device)) 1322 1323 with ops.device(device_util.get_host_for_device(device)): 1324 enqueue_op = self._generate_enqueue_op( 1325 flat_inputs, flat_weights, flat_features, 1326 device_ordinal=device_spec.device_index, 1327 mode_override=mode_override) 1328 1329 # Apply the name tag to the op. 1330 if name is not None: 1331 _add_key_attr(enqueue_op, name) 1332 ops.get_default_graph().control_outputs.append(enqueue_op) 1333 1334 def _get_input_shapes(self, tensors, 1335 in_tpu_context: bool) -> List[TensorShape]: 1336 """Get the input shapes from the input tensor.""" 1337 input_shapes = [] 1338 for (path, maybe_tensor), feature in zip( 1339 nest.flatten_with_joined_string_paths(tensors), 1340 nest.flatten(self._feature_config)): 1341 if not in_tpu_context: 1342 tensor = distribute_utils.select_replica(0, maybe_tensor) 1343 else: 1344 tensor = maybe_tensor 1345 1346 if isinstance(tensor, ops.Tensor): 1347 input_shapes.append( 1348 self._get_input_shape_for_tensor(tensor, feature, path)) 1349 elif isinstance(tensor, sparse_tensor.SparseTensor): 1350 input_shapes.append( 1351 self._get_input_shape_for_sparse_tensor(tensor, feature, path)) 1352 elif isinstance(tensor, ragged_tensor.RaggedTensor): 1353 input_shapes.append( 1354 self._get_input_shape_for_ragged_tensor(tensor, feature, path)) 1355 return input_shapes 1356 1357 def _get_input_shape_for_tensor(self, tensor, feature, path) -> TensorShape: 1358 """Get the input shape for the dense tensor.""" 1359 shape = tensor.shape.as_list() 1360 if len(shape) < 1: 1361 raise ValueError("Only rank 1 and above dense tensor is supported," 1362 " find rank {} sparse tensor for input {}".format( 1363 len(shape), path)) 1364 if len(shape) > 1 and shape[-1] != 1: 1365 raise ValueError( 1366 "Rank 2 or above dense tensor should have last dimension as 1 " 1367 "as the last dimension will always be reduced. " 1368 "Instead got dense tensor as shape {}".format(shape)) 1369 return TensorShape(shape) 1370 1371 def _get_input_shape_for_sparse_tensor(self, tensor, feature, 1372 path) -> TensorShape: 1373 """Get the input shape for the sparse tensor.""" 1374 shape = tensor.shape.as_list() 1375 # Only 2 and above rank sparse tensor is supported. 1376 if len(shape) < 2: 1377 raise ValueError("Only rank 2 and above sparse tensor is supported," 1378 " find rank {} sparse tensor for input {}".format( 1379 len(shape), path)) 1380 if not feature.output_shape and feature.max_sequence_length > 0: 1381 # If the max_sequence_length is set and the output shape for FeatureConfig 1382 # is not set, we modify the shape of the input feature. Only rank 2 1383 # feature output shape is modified 1384 if len(shape) == 2: 1385 # If the sparse tensor is 2D and max_sequence_length is set, 1386 # we need to add one dimension to the input feature. 1387 shape.insert(len(shape) - 1, feature.max_sequence_length) 1388 1389 return TensorShape(shape) 1390 1391 def _get_input_shape_for_ragged_tensor(self, tensor, feature, 1392 path) -> TensorShape: 1393 """Get the input shape for the ragged tensor.""" 1394 shape = tensor.shape.as_list() 1395 # Only rank 2 ragged tensor is supported. 1396 if len(shape) != 2: 1397 raise ValueError("Only rank 2 ragged tensor is supported," 1398 " find rank {} ragged tensor for input {}".format( 1399 len(shape), path)) 1400 if not feature.output_shape and feature.max_sequence_length > 0: 1401 # If the max_sequence_length is set and the output shape for FeatureConfig 1402 # is not set, add the sequence length as second last dimension of 1403 # the ragged tensor. 1404 shape.insert(len(shape) - 1, feature.max_sequence_length) 1405 1406 return TensorShape(shape) 1407 1408 def _update_output_shapes(self, incoming_output_shapes: List[TensorShape]): 1409 """Update the existing output shapes based on the new output shapes. 1410 1411 The existing output shapes always have higher piority than the new incoming 1412 output shapes. 1413 Args: 1414 incoming_output_shapes: nested structure of TensorShape to override the 1415 existing output shapes. 1416 """ 1417 nest.assert_same_structure(self._output_shapes, incoming_output_shapes) 1418 updated_output_shapes = [] 1419 for old_output_shape, incoming_output_shape in zip(self._output_shapes, 1420 incoming_output_shapes): 1421 if old_output_shape: 1422 updated_output_shapes.append(old_output_shape) 1423 else: 1424 updated_output_shapes.append(incoming_output_shape) 1425 self._output_shapes = updated_output_shapes 1426 1427 def _check_output_shapes(self, incoming_output_shapes: List[TensorShape]): 1428 """Check the incoming output shapes against the output shapes stored.""" 1429 # The incoming output shape should have the same structure with the existing 1430 # output shapes. 1431 nest.assert_same_structure(self._output_shapes, incoming_output_shapes) 1432 1433 for (path, _), old_output_shape, incoming_output_shape in zip( 1434 nest.flatten_with_joined_string_paths(self._feature_config), 1435 self._output_shapes, incoming_output_shapes): 1436 # First check if both shapes are not None. 1437 if old_output_shape and incoming_output_shape: 1438 # We skip the check when the incoming output shape is rank 1 or 2 and 1439 # rank of the old output shape is larger. This can happen for 1440 # (sequence) ragged tensor, we push the check down to the enqueue op. 1441 if (len(incoming_output_shape) == 1 or len(incoming_output_shape) 1442 == 2) and len(old_output_shape) > len(incoming_output_shape): 1443 continue 1444 if len(old_output_shape) != len( 1445 incoming_output_shape) or not self._is_tensor_shape_match( 1446 old_output_shape, incoming_output_shape): 1447 raise ValueError( 1448 f"Inconsistent shape founded for input feature {path}, " 1449 f"Output shape is set to be {old_output_shape}, " 1450 f"But got incoming output shape {incoming_output_shape}") 1451 1452 def _check_output_shapes_fully_defined(self): 1453 """Check if the output shape is fully defined.""" 1454 for (path, _), output_shape in zip( 1455 nest.flatten_with_joined_string_paths(self._feature_config), 1456 self._output_shapes): 1457 if not output_shape.is_fully_defined(): 1458 raise ValueError( 1459 f"Input Feature {path} has output shape set as " 1460 f"{output_shape} which is not fully defined. " 1461 "Please specify the fully defined shape in either FeatureConfig " 1462 "or for the build method.") 1463 1464 def _is_tensor_shape_match(self, shape_a: TensorShape, 1465 shape_b: TensorShape) -> bool: 1466 """Check if shape b matches with shape a.""" 1467 for s_a, s_b in zip(shape_a.as_list(), shape_b.as_list()): 1468 if s_a and s_b and s_a != s_b: 1469 return False 1470 return True 1471 1472 def _get_output_shapes_from_batch_size(self, per_replica_batch_size): 1473 """Get the output shapes from the batch size.""" 1474 output_shapes = [] 1475 for feature in nest.flatten(self._feature_config): 1476 if not feature.output_shape and feature.max_sequence_length > 0: 1477 output_shapes.append( 1478 TensorShape([per_replica_batch_size, feature.max_sequence_length])) 1479 else: 1480 output_shapes.append(TensorShape(per_replica_batch_size)) 1481 return output_shapes 1482 1483 1484@def_function.function 1485def _load_variables_impl( 1486 config: Text, 1487 hosts: List[Tuple[int, Text]], 1488 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1489 table_config: tpu_embedding_v2_utils.TableConfig): 1490 """Load embedding tables to onto TPU for each table and host. 1491 1492 Args: 1493 config: A serialized TPUEmbeddingConfiguration proto. 1494 hosts: A list of CPU devices, on per host. 1495 variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key 1496 is the table name, second key is 'parameters' or the optimizer slot name. 1497 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1498 """ 1499 def select_fn(host_id): 1500 1501 def select_or_zeros(x): 1502 if host_id >= len(x.variables): 1503 # In the edge case where we have more hosts than variables, due to using 1504 # a small number of rows, we load zeros for the later hosts. We copy 1505 # the shape of the first host's variables, which we assume is defined 1506 # because TableConfig guarantees at least one row. 1507 return array_ops.zeros_like(x.variables[0]) 1508 return x.variables[host_id] 1509 1510 return select_or_zeros 1511 1512 for host_id, host in enumerate(hosts): 1513 with ops.device(host): 1514 host_variables = nest.map_structure(select_fn(host_id), variables) 1515 for table in table_config: 1516 table.optimizer._load()( # pylint: disable=protected-access 1517 table_name=table.name, 1518 num_shards=len(hosts), 1519 shard_id=host_id, 1520 config=config, 1521 **host_variables[table.name]) 1522 # Ensure that only the first table/first host gets a config so that we 1523 # don't bloat graph by attaching this large string to each op. 1524 # We have num tables * num hosts of these so for models with a large 1525 # number of tables training on a large slice, this can be an issue. 1526 config = None 1527 1528 1529@def_function.function 1530def _retrieve_variables_impl( 1531 config: Text, 1532 hosts: List[Tuple[int, Text]], 1533 variables: Dict[Text, Dict[Text, tf_variables.Variable]], 1534 table_config: tpu_embedding_v2_utils.TableConfig): 1535 """Retrieve embedding tables from TPU to host memory. 1536 1537 Args: 1538 config: A serialized TPUEmbeddingConfiguration proto. 1539 hosts: A list of all the host CPU devices. 1540 variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key 1541 is the table name, second key is 'parameters' or the optimizer slot name. 1542 table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. 1543 """ 1544 for host_id, host in enumerate(hosts): 1545 with ops.device(host): 1546 for table in table_config: 1547 retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access 1548 table_name=table.name, 1549 num_shards=len(hosts), 1550 shard_id=host_id, 1551 config=config) 1552 # When there are no slot variables (e.g with SGD) this returns a 1553 # single tensor rather than a tuple. In this case we put the tensor in 1554 # a list to make the following code easier to write. 1555 if not isinstance(retrieved, tuple): 1556 retrieved = (retrieved,) 1557 1558 for i, slot in enumerate(["parameters"] + 1559 table.optimizer._slot_names()): # pylint: disable=protected-access 1560 # We must assign the CPU variables the values of tensors that were 1561 # returned from the TPU. 1562 sharded_var = variables[table.name][slot] 1563 if host_id < len(sharded_var.variables): 1564 # In the edge case where we have more hosts than variables, due to 1565 # using a small number of rows, we skip the later hosts. 1566 sharded_var.variables[host_id].assign(retrieved[i]) 1567 # Ensure that only the first table/first host gets a config so that we 1568 # don't bloat graph by attaching this large string to each op. 1569 # We have num tables * num hosts of these so for models with a large 1570 # number of tables training on a large slice, this can be an issue. 1571 config = None 1572 1573 1574def _save_callback(trackables, **unused_kwargs): 1575 for trackable in trackables.values(): 1576 trackable._retrieve_variables() # pylint: disable=protected-access 1577 return [] 1578 1579 1580def _restore_callback(trackables, **unused_kwargs): 1581 for trackable in trackables.values(): 1582 trackable._load_variables() # pylint: disable=protected-access 1583 1584 1585registration.register_tf_checkpoint_saver( 1586 "TPUEmbeddingCallback", 1587 predicate=lambda x: isinstance(x, TPUEmbedding), 1588 save_fn=_save_callback, 1589 restore_fn=_restore_callback, 1590 # Set strict_predicate_restore to `False` to because the isinstance 1591 # predicate check does not pass after a TPUEmbedding object is loaded from 1592 # SavedModel. 1593 strict_predicate_restore=False 1594) 1595 1596 1597def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]: 1598 """Returns a sorted list of CPU devices for the remote jobs. 1599 1600 Args: 1601 strategy: A TPUStrategy object. 1602 1603 Returns: 1604 A sort list of device strings. 1605 """ 1606 list_of_hosts = [] 1607 # Assume this is sorted by task 1608 for tpu_device in strategy.extended.worker_devices: 1609 host = device_util.get_host_for_device(tpu_device) 1610 if host not in list_of_hosts: 1611 list_of_hosts.append(host) 1612 assert len(list_of_hosts) == strategy.extended.num_hosts 1613 return list_of_hosts 1614 1615 1616def extract_variable_info( 1617 kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]: 1618 """Extracts the variable creation attributes from the kwargs. 1619 1620 Args: 1621 kwargs: a dict of keyword arguments that were passed to a variable creator 1622 scope. 1623 1624 Returns: 1625 A tuple of variable name, shape, dtype, initialization function. 1626 """ 1627 if (isinstance(kwargs["initial_value"], functools.partial) and ( 1628 "shape" in kwargs["initial_value"].keywords or 1629 kwargs["initial_value"].args)): 1630 # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. 1631 if "shape" in kwargs["initial_value"].keywords: 1632 shape = kwargs["initial_value"].keywords["shape"] 1633 else: 1634 shape = kwargs["initial_value"].args[0] 1635 return (kwargs["name"], shape, 1636 kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), 1637 kwargs["initial_value"].func) 1638 elif "shape" not in kwargs or kwargs["shape"] is None or not callable( 1639 kwargs["initial_value"]): 1640 raise ValueError( 1641 "Unable to extract initializer function and shape from {}. Please " 1642 "either pass a function that expects a shape and dtype as the " 1643 "initial value for your variable or functools.partial object with " 1644 "the shape and dtype kwargs set. This is needed so that we can " 1645 "initialize the shards of the ShardedVariable locally.".format( 1646 kwargs["initial_value"])) 1647 else: 1648 return (kwargs["name"], kwargs["shape"], kwargs["dtype"], 1649 kwargs["initial_value"]) 1650 1651 1652def make_sharded_variable_creator( 1653 hosts: List[Text]) -> Callable[..., TPUEmbeddingVariable]: 1654 """Makes a sharded variable creator given a list of hosts. 1655 1656 Args: 1657 hosts: a list of tensorflow devices on which to shard the tensors. 1658 1659 Returns: 1660 A variable creator function. 1661 """ 1662 1663 def sharded_variable_creator( 1664 next_creator: Callable[..., tf_variables.Variable], *args, **kwargs): 1665 """The sharded variable creator.""" 1666 kwargs["skip_mirrored_creator"] = True 1667 1668 num_hosts = len(hosts) 1669 name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs) 1670 initial_value = kwargs["initial_value"] 1671 rows = shape[0] 1672 cols = shape[1] 1673 partial_partition = rows % num_hosts 1674 full_rows_per_host = rows // num_hosts 1675 # We partition as if we were using MOD sharding: at least 1676 # `full_rows_per_host` rows to `num_hosts` hosts, where the first 1677 # `partial_partition` hosts get an additional row when the number of rows 1678 # is not cleanly divisible. Note that `full_rows_per_host` may be zero. 1679 partitions = ( 1680 [full_rows_per_host + 1] * partial_partition 1681 + [full_rows_per_host] * (num_hosts - partial_partition)) 1682 variables = [] 1683 sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args 1684 1685 # Keep track of offset for sharding aware initializers. 1686 offset = 0 1687 kwargs["dtype"] = dtype 1688 for i, p in enumerate(partitions): 1689 if p == 0: 1690 # Skip variable creation for empty partitions, resulting from the edge 1691 # case of 'rows < num_hosts'. This is safe because both load/restore 1692 # can handle the missing values. 1693 continue 1694 with ops.device(hosts[i]): 1695 kwargs["name"] = "{}_{}".format(name, i) 1696 kwargs["shape"] = (p, cols) 1697 if sharding_aware: 1698 shard_info = base.ShardInfo(kwargs["shape"], (offset, 0)) 1699 kwargs["initial_value"] = functools.partial( 1700 initial_value, shard_info=shard_info) 1701 offset += p 1702 else: 1703 kwargs["initial_value"] = functools.partial( 1704 unwrapped_initial_value, kwargs["shape"], dtype=dtype) 1705 variables.append(next_creator(*args, **kwargs)) 1706 return TPUEmbeddingVariable(variables, name=name) 1707 return sharded_variable_creator 1708