# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings."""

import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union

from absl import logging

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import save_context
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_HOOK_KEY = "TPUEmbedding_saveable"
_NAME_KEY = "_tpu_embedding_layer"


class TPUEmbeddingVariable(sharded_variable.ShardedVariableMixin):
  """A ShardedVariable class for TPU."""

  @property
  def _in_graph_mode(self):
    return self.variables[0]._in_graph_mode  # pylint: disable=protected-access


def _add_key_attr(op, name):
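  # Tag the op with the user-supplied name under the "_tpu_embedding_layer"
  # attr so that the enqueue/dequeue/apply_gradients ops belonging to the same
  # logical embedding layer can be identified later.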
  op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name)))  # pylint: disable=protected-access


@tf_export("tpu.experimental.embedding.TPUEmbedding")
class TPUEmbedding(autotrackable.AutoTrackable):
  """The TPUEmbedding mid level API.

  NOTE: When instantiated under a TPUStrategy, this class can only be created
  once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
  re-initialize the embedding engine you must re-initialize the TPU as well.
  Doing this will clear any variables from TPU, so ensure you have checkpointed
  before you do this. If further instances of the class are needed,
  set the `initialize_tpu_embedding` argument to `False`.

  This class can be used to support training large embeddings on TPU. When
  creating an instance of this class, you must specify the complete set of
  tables and features you expect to look up in those tables. See the
  documentation of `tf.tpu.experimental.embedding.TableConfig` and
  `tf.tpu.experimental.embedding.FeatureConfig` for more details on the
  complete set of options. We will cover the basic usage here.

  NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
  allowing different features to share the same table:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  ```

  There are two modes under which the `TPUEmbedding` class can be used. Which
  one applies depends on whether the class was created under a `TPUStrategy`
  scope or not.

  Under `TPUStrategy`, we allow access to the methods `enqueue`, `dequeue` and
  `apply_gradients`. We will show examples below of how to use these to train
  and evaluate your model. Under CPU, you only have access to the
  `embedding_tables` property, which allows access to the embedding tables so
  that you can use them to run model evaluation/prediction on CPU.

  First let's look at the `TPUStrategy` mode. Initial setup looks like:

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the enqueue
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```
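
  As a minimal sketch, a `dataset_fn` for the call above might look like the
  following (the feature names, id tensors and `global_batch_size` are
  assumptions for illustration):

  ```python
  def dataset_fn(ctx: tf.distribute.InputContext):
    per_replica_batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # A dataset of id tensors matching the structure of feature_config.
    dataset = tf.data.Dataset.from_tensor_slices({
        'feature_one': feature_one_ids,
        'feature_two': feature_two_ids,
        'feature_three': feature_three_ids})
    return dataset.repeat().batch(per_replica_batch_size, drop_remainder=True)
  ```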

  Different feature inputs can have different shapes. For dense and sparse
  tensors, rank 2 and above is supported. For ragged tensors, although only
  rank 2 is supported, you can specify the output shape to be rank 2 and
  above. The output shape specified in the `FeatureConfig` has the highest
  priority, the input shape passed to the build method has second priority,
  and the input shapes auto detected from the input features have the lowest
  priority. The latter two are converted to output shapes by omitting the last
  dimension. If a lower priority source produces output shapes that don't
  match the higher priority one, a ValueError is raised; a lower priority
  source can only override output shapes that the higher priority one leaves
  undefined. An example of calling build with explicit shapes follows the note
  below.

  NOTE: All batches passed to the layer can have different input shapes, but
  these input shapes need to match the output shapes set by either the
  `FeatureConfig` or the build method, except for ragged tensors: a 2D ragged
  tensor with an output shape set to higher dimensions is allowed as long as
  the total number of elements matches. All subsequent calls must have the
  same input shapes. In the event that the input shapes cannot be
  automatically determined by the enqueue method, you must call
  the build method with the input shapes or provide output shapes in the
  `FeatureConfig` to initialize the layer.
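
  As a sketch, assuming the `feature_config` dict from above and a per replica
  batch size of 16 with rank 2 inputs whose last dimension is 1:

  ```python
  embedding.build(per_replica_input_shapes={
      'feature_one': tf.TensorShape([16, 1]),
      'feature_two': tf.TensorShape([16, 1]),
      'feature_three': tf.TensorShape([16, 1])})
  ```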

  To use this API on TPU you should use a custom training loop. Below is an
  example of a training and evaluation step:

  ```python
  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      activations = embedding.dequeue()
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=False)
      strategy.run(tpu_step, args=(tpu_features, ))
  ```

  NOTE: The calls to `enqueue` have `training` set to `True` when
  `embedding.apply_gradients` is used and set to `False` when
  `embedding.apply_gradients` is not present in the function. If you don't
  follow this pattern you may cause an error to be raised or the TPU may
  deadlock.

  In the above examples, we assume that the user has a dataset which returns
  a tuple where the first element of the tuple matches the structure of what
  was passed as the `feature_config` argument to the object initializer. Also
  we utilize `tf.range` to get a `tf.while_loop` in order to increase
  performance.

  When checkpointing your model, you should include your
  `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is
  a trackable object and saving it will save the embedding tables and their
  optimizer slot variables:

  ```python
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.save(...)
  ```

  On CPU, only the `embedding_tables` property is usable. This will allow you
  to restore a checkpoint to the object and have access to the table variables:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  tables = embedding.embedding_tables
  ```

  You can now use these tables in functions like `tf.nn.embedding_lookup` to
  perform your embedding lookup and pass to your model.
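
  For example (a sketch, assuming `table_config_one` from above and an integer
  id tensor `ids`):

  ```python
  activations = tf.nn.embedding_lookup(tables[table_config_one], ids)
  ```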

  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=tf.tpu.experimental.embedding.FeatureConfig(
              table=tf.tpu.experimental.embedding.TableConfig(
                  dim=...,
                  vocabulary_size=...)))
    ```
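
    When not created under a `TPUStrategy` (e.g. when exporting for serving),
    the optimizer may be `None` so that no slot variables are created; a
    sketch:

    ```python
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=None)
    ```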

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. When not created under
        TPUStrategy, this may be set to None to avoid the creation of the
        optimizer slot variables; this is useful for reducing memory
        consumption when exporting the model for serving, where slot variables
        aren't needed.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old). Set to True for improved performance.

    Raises:
      ValueError: If the optimizer is not an instance of
        `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`, or if it is None when created
        under a TPUStrategy.
    """
    self._strategy = distribution_strategy_context.get_strategy()
    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
                                                  tpu_strategy.TPUStrategyV2))
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._feature_config = feature_config
    self._output_shapes = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)

    # The TPU embedding ops are slightly inconsistent with how they refer to
    # tables:
    # * The enqueue op takes a parallel list of tensors for input, one of those
    #   is the table id for the feature which matches the integer index of the
    #   table in the proto created by _create_config_proto().
    # * The recv_tpu_embedding_activations op emits lookups per table in the
    #   order from the config proto.
    # * The send_tpu_embedding_gradients expects input tensors to be per table
    #   in the same order as the config proto.
    # * Per optimizer load and retrieve ops are specified per table and take
    #   the table name rather than the table id.
    # Thus we must fix a common order to tables and ensure they have unique
    # names.

    # Set table order here to the order of the first occurrence of the table
    # in a feature provided by the user. The order of this struct must be fixed
    # to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if ((table.optimizer is not None or self._using_tpu) and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    if self._using_tpu:
      # Extract a list of callable learning rates also in fixed order. Each
      # table in the config proto will get an index into this list and we will
      # pass this list in the same order after evaluation to the
      # send_tpu_embedding_gradients op.
      self._dynamic_learning_rates = list({
          table.optimizer.learning_rate for table in self._table_config if
          callable(table.optimizer.learning_rate)})

      # We need the list of host devices for the load/retrieve operations.
      self._hosts = get_list_of_hosts(self._strategy)

    self._built = False
    self._verify_output_shapes_on_enqueue = True

  def build(self, per_replica_input_shapes=None, per_replica_batch_size=None):  # pylint:disable=g-bare-generic
    """Creates the underlying variables and initializes the TPU for embeddings.

    This method creates the underlying variables (including slot variables). If
    created under a TPUStrategy, this will also initialize the TPU for
    embeddings.

    This function will automatically get called by enqueue, which will try to
    determine your output shapes. If this fails, you must manually
    call this method before you call enqueue.
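
    As a sketch, you can derive a per replica batch size from a global one and
    build explicitly (`global_batch_size` and `feature_config` here are
    assumed names):

    ```python
    per_replica_batch_size = (
        global_batch_size // strategy.num_replicas_in_sync)
    embedding.build(per_replica_input_shapes=tf.nest.map_structure(
        lambda _: tf.TensorShape([per_replica_batch_size, 1]),
        feature_config))
    ```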

    Args:
      per_replica_input_shapes: A nested structure of the per replica input
        shapes that matches the structure of the feature config. The input
        shapes should be the same as the input shape of the feature (except
        for ragged tensors). Note that it is fixed and the same per replica
        input shapes must be used for both training and evaluation. If you
        want to calculate this from the global input shapes, you can use the
        `num_replicas_in_sync` property of your strategy object. May be set to
        None if not created under a TPUStrategy.
      per_replica_batch_size: (Deprecated) The per replica batch size that you
        intend to use. Note that it is fixed and the same batch size must be
        used for both training and evaluation. If you want to calculate this
        from the global batch size, you can use the `num_replicas_in_sync`
        property of your strategy object. May be set to None if not created
        under a TPUStrategy.

    Raises:
      ValueError: If per_replica_input_shapes is inconsistent with the output
        shapes stored in the feature config or if the output shapes derived
        from the input shapes are not fully defined.
      RuntimeError: If tpu embedding is already initialized on TPU.
372    """
373    if self._built:
374      return
375
376    if self._using_tpu:
377      # If the tpu embedding is already initialized on TPU, raise runtime error.
378      # Below logic is not added in `initialize_system_for_tpu_embedding`
379      # because doing exception control flow in graph mode is difficult.
380      if tpu_ops.is_tpu_embedding_initialized():
381        raise RuntimeError(
382            "TPU is already initialized for embeddings. This may be caused by "
383            "using multiple TPUEmbedding instances in a TPU scope which is "
384            "unsupported")
385      self._get_and_update_output_shapes_from_input(per_replica_input_shapes,
386                                                    per_replica_batch_size)
387
388      self._config_proto = self._create_config_proto()
389
390      logging.info("Initializing TPU Embedding engine.")
391      tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
392
393      @def_function.function
394      def load_config():
395        tpu.initialize_system_for_tpu_embedding(self._config_proto)
396
397      load_config()
398      logging.info("Done initializing TPU Embedding engine.")
399
400    # Create and load variables and slot variables into the TPU.
401    # Note that this is a dict of dicts. Keys to the first dict are table names.
402    # We would prefer to use TableConfigs, but then these variables won't be
403    # properly tracked by the tracking API.
404    self._variables = self._create_variables_and_slots()
405
406    self._built = True
407
408    # This is internally conditioned self._built and self._using_tpu
409    self._load_variables()
410

  def _maybe_build(self,
                   output_shapes: Optional[Union[List[int], Iterable]] = None):  # pylint:disable=g-bare-generic
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means that
      # it will not be included in the function graph generated by tracing, so
      # that we can be sure that we only initialize the TPU for embeddings
      # exactly once.
      with ops.init_scope():
        self.build(output_shapes)

  def _get_and_update_output_shapes_from_input(
      self,
      per_replica_input_shapes: Optional[List[TensorShape]] = None,
      per_replica_batch_size: Optional[int] = None):
    """Get and update the per replica output shapes from the input."""
    per_replica_output_shapes = None
    if per_replica_batch_size and per_replica_input_shapes is None:
      logging.warning(
          "per_replica_batch_size argument will be deprecated, please specify "
          "all the input shapes using per_replica_input_shapes argument.")
      per_replica_output_shapes = self._get_output_shapes_from_batch_size(
          per_replica_batch_size)

    # Update the input shapes if provided.
    if per_replica_input_shapes is not None:
      if isinstance(per_replica_input_shapes, int):
        logging.warning(
            "Passing batch size to per_replica_input_shapes argument will be"
            " deprecated, please specify all the input shapes using"
            " per_replica_input_shapes argument.")
        per_replica_output_shapes = self._get_output_shapes_from_batch_size(
            per_replica_input_shapes)
      else:
        nest.assert_same_structure(
            nest.flatten(per_replica_input_shapes),
            nest.flatten(self._feature_config))

        # Convert the nested structure to list.
        per_replica_input_shapes = nest.flatten(per_replica_input_shapes)

        per_replica_output_shapes = self._get_output_shapes_from_input_shapes(
            per_replica_input_shapes)

    if per_replica_output_shapes is not None:

      # Check the output shapes with existing output shapes setting.
      self._check_output_shapes(per_replica_output_shapes)

      # Update the output shapes with existing output shapes setting.
      # This is necessary because the output shapes might be missing from
      # the feature config, in which case the user can set them by:
      #  1. calling the build method with input shapes, or
      #  2. having the output shapes auto detected when calling the enqueue
      #     method for the first time; enqueue will then call the build method
      #     with the output shapes.
      # Either of these two situations will lead to an update of the existing
      # output shapes.
      self._update_output_shapes(per_replica_output_shapes)

    # Check if the output shapes are fully defined. This is required in order
    # to set them in the feature descriptor field of the tpu embedding config
    # proto.
    self._check_output_shapes_fully_defined()
475
476  def _get_output_shapes_from_input_shapes(
477      self, input_shapes: List[TensorShape]) -> List[TensorShape]:
478    """Get output shapes from the flattened input shapes list."""
479    output_shapes = []
480    for input_shape, feature in zip(input_shapes,
481                                    nest.flatten(self._feature_config)):
482      if input_shape.rank is None or input_shape.rank < 1:
483        raise ValueError(
484            "Received input tensor of shape {}. Rank must be 1 and above"
485            .format(input_shape))
486      # Update the input shape with the max sequence length. Only update when
487      # 1. Input feature is 2D ragged or sparse tensor.
488      # 2. Output shape is not set in the feature config and the max sequence
489      #    length is set.
490      if (len(input_shape) == 2 and input_shape[-1] != 1 and
491          not feature.output_shape and feature.max_sequence_length > 0):
492        input_shape_list = input_shape.as_list()
493        input_shape_list.insert(
494            len(input_shape_list) - 1, feature.max_sequence_length)
495        input_shape = TensorShape(input_shape_list)
496      if input_shape.rank == 1:
497        output_shapes.append(input_shape)
498      else:
499        output_shapes.append(input_shape[:-1])
500    return output_shapes
501
502  @property
503  def embedding_tables(
504      self
505  ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
506    """Returns a dict of embedding tables, keyed by `TableConfig`.
507
508    This property only works when the `TPUEmbedding` object is created under a
509    non-TPU strategy. This is intended to be used to for CPU based lookup when
510    creating a serving checkpoint.
511
512    Returns:
513      A dict of embedding tables, keyed by `TableConfig`.
514
515    Raises:
516      RuntimeError: If object was created under a `TPUStrategy`.
517    """
518    # We don't support returning tables on TPU due to their sharded nature and
519    # the fact that when using a TPUStrategy:
520    # 1. Variables are stale and are only updated when a checkpoint is made.
521    # 2. Updating the variables won't affect the actual tables on the TPU.
522    if self._using_tpu:
523      if save_context.in_save_context():
524        return {table: self._variables[table.name]["parameters"].variables[0]
525                for table in self._table_config}
526      raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
527                         "strategy. If you need access, save your model, "
528                         "create this object under a CPU strategy and restore.")
529
530    self._maybe_build(None)
531
532    # Only return the tables and not the slot variables. On CPU this are honest
533    # tf.Variables.
534    return {table: self._variables[table.name]["parameters"]
535            for table in self._table_config}
536
  def _create_config_proto(
      self
  ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration:
    """Creates the TPUEmbeddingConfiguration proto.

    This proto is used to initialize the TPU embedding engine.

    Returns:
      A TPUEmbeddingConfiguration proto.
    """

    config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()

    # Map each callable dynamic learning rate to its index in the list.
    # The learning rate index is the index of the dynamic learning rate for
    # this table (if it exists) in the list we created at initialization. We
    # don't simply create one learning rate index per table as this has
    # extremely bad performance characteristics. The more separate optimization
    # configurations we have, the worse the performance will be.
    learning_rate_index = {r: i for i, r in enumerate(
        self._dynamic_learning_rates)}

    for table in self._table_config:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table.name

      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table.vocabulary_size,
                                             self._strategy.extended.num_hosts)
      table_descriptor.dimension = table.dim

      parameters = table_descriptor.optimization_parameters

      # We handle the learning rate separately here and don't allow the
      # optimization class to handle this, as it doesn't know about dynamic
      # rates.
      if callable(table.optimizer.learning_rate):
        parameters.learning_rate.dynamic.tag = (
            learning_rate_index[table.optimizer.learning_rate])
      else:
        parameters.learning_rate.constant = table.optimizer.learning_rate

      # Use optimizer to handle the rest of the parameters.
      table.optimizer._set_optimization_parameters(parameters)  # pylint: disable=protected-access

    table_to_id = {table: i for i, table in enumerate(self._table_config)}

    # Set feature descriptor field in the config proto.
    for feature, output_shape in zip(
        nest.flatten(self._feature_config), self._output_shapes):
      feature_descriptor = config_proto.feature_descriptor.add()

      if feature.name:
        feature_descriptor.name = feature.name

      feature_descriptor.table_id = table_to_id[feature.table]
      # The input shape of the feature is the actual shape of the input tensor
      # except the last dimension because the last dimension will always be
      # reduced.
      feature_descriptor.input_shape.extend(output_shape.as_list())

    # Always set mode to training, we override the mode during enqueue.
    config_proto.mode = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING)

    config_proto.num_hosts = self._strategy.extended.num_hosts
    config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync

    # TODO(bfontain): Allow users to pick MOD for the host sharding.
    config_proto.sharding_strategy = (
        tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return config_proto

  def apply_gradients(self, gradients, name: Optional[Text] = None):
    """Applies the gradient update to the embedding tables.

    If a gradient of `None` is passed in any position of the nested structure,
    then a gradient update with a zero gradient is applied for that feature.
    For optimizers like SGD or Adagrad, this is the same as applying no update
    at all. For lazy Adam and other sparsely applied optimizers with decay,
    ensure you understand the effect of applying a zero gradient.
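
    For example (a sketch, using the `feature_config` dict from the class
    docstring; the gradient tensors are assumptions), a feature can be left
    un-updated for a step by passing `None` in its position:

    ```python
    embedding.apply_gradients({
        'feature_one': grad_one,
        'feature_two': None,  # Sends a zero gradient for this feature.
        'feature_three': grad_three})
    ```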

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ... #  some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      gradients: A nested structure of gradients, with structure matching the
        `feature_config` passed to this object.
      name: A name for the underlying op.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
      ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a
        `tf.Tensor` of the incorrect shape is passed in. Also if
        the size of any sequence in `gradients` does not match the
        corresponding sequence in `feature_config`.
      TypeError: If the type of any sequence in `gradients` does not match the
        corresponding sequence in `feature_config`.
    """
    if not self._using_tpu:
      raise RuntimeError("apply_gradients is not valid when TPUEmbedding "
                         "object is not created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding "
                         "object. Please either call enqueue first or manually "
                         "call the build method.")

    nest.assert_same_structure(self._feature_config, gradients)
    updated_gradients = []
    for (path, gradient), feature, output_shape in zip(
        nest.flatten_with_joined_string_paths(gradients),
        nest.flatten(self._feature_config), self._output_shapes):
      full_output_shape = list(output_shape) + [feature.table.dim]
      if gradient is not None and not isinstance(gradient, ops.Tensor):
        raise ValueError(
            f"found non-tensor type: {type(gradient)} at path {path}.")
      if gradient is not None:
        if gradient.shape != full_output_shape:
          raise ValueError("Found gradient of shape {} at path {}. Expected "
                           "shape {}.".format(gradient.shape, path,
                                              full_output_shape))
      else:
        # No gradient for this feature. Since we must give a gradient for all
        # features, pass in a zero tensor here. Note that this is not correct
        # for all optimizers.
        logging.warning(
            "No gradient passed for feature %s, sending zero "
            "gradient. This may not be correct behavior for certain "
            "optimizers like Adam.", path)
        gradient = array_ops.zeros(full_output_shape, dtype=dtypes.float32)
      # Some gradients can be passed from ops whose shape is not correctly
      # set. This ensures that the shape of the gradient is correctly set.
      updated_gradients.append(
          array_ops.reshape(gradient, shape=gradient.shape))
    op = tpu_ops.send_tpu_embedding_gradients(
        inputs=updated_gradients,
        learning_rates=[
            math_ops.cast(fn(), dtype=dtypes.float32)
            for fn in self._dynamic_learning_rates
        ],
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(op, name)

  def dequeue(self, name: Optional[Text] = None):
    """Get the embedding results.

    Returns a nested structure of `tf.Tensor` objects, matching the structure
    of the `feature_config` argument to the `TPUEmbedding` class. The output
    shape of the tensors is `(*output_shape, dim)`, where `dim` is the
    dimension of the corresponding `TableConfig`. There are three places where
    the output_shape can be set:
      1. The FeatureConfig provided in the __init__ function.
      2. The per_replica_output_shapes passed by directly calling the build
         method after initializing the tpu embedding class.
      3. Auto detected from the shapes of the input feature.
    These are listed in descending order of priority.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ... #  some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

    training_step()
    ```

    Args:
      name: A name for the underlying op.

    Returns:
      A nested structure of tensors, with the same structure as
      `feature_config` passed to this instance of the `TPUEmbedding` object.

    Raises:
      RuntimeError: If called when object wasn't created under a `TPUStrategy`
        or if not built (either by manually calling build or calling enqueue).
    """
    if not self._using_tpu:
      raise RuntimeError("dequeue is not valid when TPUEmbedding object is not "
                         "created under a TPUStrategy.")

    if not self._built:
      raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. "
                         "Please either call enqueue first or manually call "
                         "the build method.")

    # The activations returned by this op are per feature.
    activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._config_proto.feature_descriptor),
        config=self._config_proto.SerializeToString())

    # Apply the name tag to the op.
    if name is not None:
      _add_key_attr(activations[0].op, name)

    # Pack the list back into the same nested structure as the features.
    return nest.pack_sequence_as(self._feature_config, activations)

  def _create_variables_and_slots(
      self
  ) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note under TPUStrategy this will ensure that all creations happen within a
    variable creation scope of the sharded variable creator.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
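
      For example (an illustrative sketch; slot names depend on the
      optimizer), a table named 'table_0' trained with Adagrad might map to
      {'parameters': <tf.Variable>, 'accumulators': <tf.Variable>}.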
805    """
806
807    def create_variables(table):
808      """Create all variables."""
809      variable_shape = (table.vocabulary_size, table.dim)
810
811      def getter(name, shape, dtype, initializer, trainable):
812        del shape
813        # _add_variable_with_custom_getter clears the shape sometimes, so we
814        # take the global shape from outside the getter.
815        initial_value = functools.partial(initializer, variable_shape,
816                                          dtype=dtype)
817        return tf_variables.Variable(
818            name=name,
819            initial_value=initial_value,
820            shape=variable_shape,
821            dtype=dtype,
822            trainable=trainable)
823
824      def variable_creator(name, initializer, trainable=True):
825        # use add_variable_with_custom_getter here so that we take advantage of
826        # the checkpoint loading to allow restore before the variables get
827        # created which avoids double initialization.
828        return self._add_variable_with_custom_getter(
829            name=name,
830            initializer=initializer,
831            shape=variable_shape,
832            dtype=dtypes.float32,
833            getter=getter,
834            trainable=trainable)
835
836      parameters = variable_creator(table.name, table.initializer,
837                                    trainable=not self._using_tpu)
838
839      def slot_creator(name, initializer):
840        return variable_creator(table.name + "/" + name,
841                                initializer,
842                                False)
843
844      if table.optimizer is not None:
845        slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
846      else:
847        slot_vars = {}
848      slot_vars["parameters"] = parameters
849      return slot_vars
850
851    # Store tables based on name rather than TableConfig as we can't track
852    # through dicts with non-string keys, i.e. we won't be able to save.
853    variables = {}
854    for table in self._table_config:
855      if not self._using_tpu:
856        variables[table.name] = create_variables(table)
857      else:
858        with variable_scope.variable_creator_scope(
859            make_sharded_variable_creator(self._hosts)):
860          variables[table.name] = create_variables(table)
861
862    return variables

  def _load_variables(self):
    # Only load the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _load_variables_impl(self._config_proto.SerializeToString(),
                           self._hosts,
                           self._variables,
                           self._table_config)

  def _retrieve_variables(self):
    # Only retrieve the variables if we are:
    # 1) Using TPU
    # 2) Variables are created
    # 3) Not in save context (except if running eagerly)
    if self._using_tpu and self._built and not (
        not context.executing_eagerly() and save_context.in_save_context()):
      _retrieve_variables_impl(self._config_proto.SerializeToString(),
                               self._hosts,
                               self._variables,
                               self._table_config)

  # Some helper functions for the below enqueue function.
  def _add_data_for_tensor(self, tensor, weight, indices, values, weights,
                           int_zeros, float_zeros, path):
    if weight is not None:
      raise ValueError(
          "Weight specified for dense input {}, which is not allowed. "
          "Weight will always be 1 in this case.".format(path))
    # For tensors, there are no indices and no weights.
    indices.append(int_zeros)
    values.append(math_ops.cast(array_ops.reshape(tensor, [-1]), dtypes.int64))
    weights.append(float_zeros)

  def _add_data_for_sparse_tensor(self, tensor, weight, indices, values,
                                  weights, int_zeros, float_zeros, path,
                                  feature):
    sample_indices = math_ops.cast(tensor.indices, dtypes.int32)
    if tensor.shape.rank == 2:
      if not feature.output_shape and feature.max_sequence_length > 0:
        # Add one dimension to the last axis.
        sample_indices = array_ops.pad(
            sample_indices, paddings=[[0, 0], [0, 1]])
    indices.append(sample_indices)
    values.append(math_ops.cast(tensor.values, dtypes.int64))
    # If we have weights they must be a SparseTensor.
    if weight is not None:
      if not isinstance(weight, sparse_tensor.SparseTensor):
        raise ValueError("Weight for {} is type {} which does not match "
                         "type input which is SparseTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _add_data_for_ragged_tensor(self, tensor, weight, row_splits, values,
                                  weights, int_zeros, float_zeros, path,
                                  feature):
    row_splits.append(math_ops.cast(tensor.row_splits, dtypes.int32))
    values.append(math_ops.cast(tensor.values, dtypes.int64))
    # If we have weights they must be a RaggedTensor.
    if weight is not None:
      if not isinstance(weight, ragged_tensor.RaggedTensor):
        raise ValueError("Weight for {} is type {} which does not match "
                         "type input which is RaggedTensor.".format(
                             path, type(weight)))
      weights.append(math_ops.cast(weight.values, dtypes.float32))
    else:
      weights.append(float_zeros)

  def _generate_enqueue_op(
      self,
      flat_inputs: List[internal_types.NativeObject],
      flat_weights: List[Optional[internal_types.NativeObject]],
      flat_features: List[tpu_embedding_v2_utils.FeatureConfig],
      device_ordinal: int,
      mode_override: Text
  ) -> ops.Operation:
    """Outputs the enqueue op given the inputs and weights.

    Args:
      flat_inputs: A list of input tensors.
      flat_weights: A list of input weights (or None) of the same length as
        flat_inputs.
      flat_features: A list of FeatureConfigs of the same length as flat_inputs.
      device_ordinal: The device to create the enqueue op for.
      mode_override: A tensor containing the string "train" or "inference".

    Returns:
      The enqueue op.
    """
    # Combiners are per table, listed in the same order as the table order.
    combiners = [table.combiner for table in self._table_config]

    # These parallel arrays will be the inputs to the enqueue op.
    # sample_indices for sparse, row_splits for ragged.
    indices_or_row_splits = []
    values = []
    weights = []

    # We have to supply an empty/zero tensor in a list position where we don't
    # have data (e.g. indices for standard Tensor input, weight when no weight
    # is specified). We create one op here per call, so that we reduce the
    # graph size.
    int_zeros = array_ops.zeros((0,), dtype=dtypes.int32)
    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)

    # In the following loop we insert casts so that everything is either int32
    # or float32. This is because op inputs which are lists of tensors must be
    # of the same type within the list. Moreover the CPU implementations of
    # these ops cast to these types anyway, so we don't lose any data by
    # casting early.
    for inp, weight, (path, feature) in zip(
        flat_inputs, flat_weights, flat_features):
      if isinstance(inp, ops.Tensor):
        self._add_data_for_tensor(inp, weight, indices_or_row_splits, values,
                                  weights, int_zeros, float_zeros, path)
      elif isinstance(inp, sparse_tensor.SparseTensor):
        self._add_data_for_sparse_tensor(inp, weight, indices_or_row_splits,
                                         values, weights, int_zeros,
                                         float_zeros, path, feature)
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        self._add_data_for_ragged_tensor(inp, weight, indices_or_row_splits,
                                         values, weights, int_zeros,
                                         float_zeros, path, feature)
      else:
        raise ValueError("Input {} is of unknown type {}. Please only pass "
                         "Tensor, SparseTensor or RaggedTensor as input to "
                         "enqueue.".format(path, type(inp)))

    return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
        sample_indices_or_row_splits=indices_or_row_splits,
        embedding_indices=values,
        aggregation_weights=weights,
        mode_override=mode_override,
        device_ordinal=device_ordinal,
        combiners=combiners)
  def _raise_error_for_incorrect_control_flow_context(self):
    """Raises an error if we are not in the TPUReplicateContext."""
    # Do not allow any XLA control flow (i.e. control flow in between a
    # TPUStrategy's run call and the call to this function), as we can't
    # extract the enqueue from the head when in XLA control flow.
    graph = ops.get_default_graph()
    in_tpu_ctx = False
    while graph is not None:
      ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
      while ctx is not None:
        if isinstance(ctx, tpu.TPUReplicateContext):
          in_tpu_ctx = True
          break
        ctx = ctx.outer_context
      if in_tpu_ctx:
        break
      graph = getattr(graph, "outer_graph", None)
    if graph != ops.get_default_graph() and in_tpu_ctx:
      raise RuntimeError(
          "Current graph {} does not match graph which contains "
          "TPUReplicateContext {}. This is most likely due to the fact that "
          "enqueueing embedding data is called inside control flow or a "
          "nested function inside `strategy.run`. This is not supported "
          "because outside compilation fails to extract the enqueue ops as "
          "head of computation.".format(ops.get_default_graph(), graph))
    return in_tpu_ctx

  def _raise_error_for_non_direct_inputs(self, features):
    """Checks all tensors in features to see if they are a direct input."""

    # expand_composites here is important: as composite tensors pass through
    # tpu.replicate, they get 'flattened' into their component tensors and then
    # repacked before being passed to the tpu function. This means that it is
    # the component tensors which are produced by an op with the
    # "_tpu_input_identity" attribute.
    for path, input_tensor in nest.flatten_with_joined_string_paths(
        features, expand_composites=True):
      if input_tensor.op.type == "Placeholder":
        continue
      try:
        is_input = input_tensor.op.get_attr("_tpu_input_identity")
      except ValueError:
        is_input = False
      if not is_input:
        raise ValueError(
            "Received input tensor {} which is the output of op {} (type {}) "
            "which does not have the `_tpu_input_identity` attr. Please "
            "ensure that the inputs to this layer are taken directly from "
            "the arguments of the function called by "
            "strategy.run. Two possible causes are: dynamic batch size "
            "support or you are using a keras layer and are not passing "
            "tensors which match the dtype of the `tf.keras.Input`s. "
            "If you are triggering dynamic batch size support, you can "
            "disable it by passing tf.distribute.RunOptions("
            "experimental_enable_dynamic_batch_size=False) to the options "
            "argument of strategy.run().".format(path,
                                                 input_tensor.op.name,
                                                 input_tensor.op.type))

  def _raise_error_for_inputs_not_on_cpu(self, flat_inputs, flat_paths):
    """Checks all tensors in features to see if they are placed on the CPU."""

    def check_device(path, device_string):
      spec = tf_device.DeviceSpec.from_string(device_string)
      if spec.device_type == "TPU":
        raise ValueError(
            "Received input tensor {} which is on a TPU input device {}. Input "
            "tensors for TPU embeddings must be placed on the CPU. Please "
            "ensure that your dataset is prefetching tensors to the host by "
            "setting the 'experimental_fetch_to_device' option of the "
            "dataset distribution function. See the documentation of the "
            "enqueue method for an example.".format(path, device_string))

    # expand_composites here is important, we need to check the device of each
    # underlying tensor.
    for input_tensor, input_path in zip(flat_inputs, flat_paths):
      if nest.is_nested_or_composite(input_tensor):
        input_tensors = nest.flatten(input_tensor, expand_composites=True)
      else:
        input_tensors = [input_tensor]
      for t in input_tensors:
        if (t.op.type == "Identity" and
            t.op.inputs[0].op.type == "TPUReplicatedInput"):
          for tensor in t.op.inputs[0].op.inputs:
            check_device(input_path, tensor.device)
        else:
          check_device(input_path, t.device)

  def enqueue(
      self,
      features,
      weights=None,
      training: bool = True,
      name: Optional[Text] = None,
      device: Optional[Text] = None):
    """Enqueues id tensors for embedding lookup.

    This function enqueues a structure of features to be looked up in the
    embedding tables. We expect that the input shapes of each of the tensors
    in features match the output shapes set via the FeatureConfig or the build
    method (if any). Otherwise, the output shapes will be auto detected based
    on the input shapes together with the max_sequence_length or output shape
    settings in the FeatureConfig. Note that the output shapes are based on
    the per replica batch size.
    If your input dataset is batched to the global batch size and you use
    `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
    or if you use `distribute_datasets_from_function` and batch
    to the per core batch size computed by the context passed to your input
    function, the output shapes should match automatically.

    The output shapes are auto detected as follows (a worked example follows
    the list):
      1. For a dense tensor, if rank 2 or above, make sure the tensor has its
         last dimension as 1. The output shape will be the input shape
         excluding the last dimension.
      2. For a sparse tensor, make sure the tensor has rank 2 and above.
           a. If the feature config has max_sequence_length equal to 0 or an
              output shape set (the max_sequence_length setting will be
              ignored), the output shape will be the input shape excluding
              the last dimension.
           b. Otherwise, if the tensor is rank 2, the output shape will be the
              input shape with the last dimension set to max_sequence_length.
              If the tensor is above rank 2, the output shape will be the
              input shape excluding the last dimension, with the last
              dimension of the output shape set to max_sequence_length.
      3. For a ragged tensor, make sure the tensor has rank 2.
           a. If the feature config has max_sequence_length equal to 0 or an
              output shape set (the max_sequence_length setting will be
              ignored), the output shape will be the input shape excluding
              the last dimension.
           b. Otherwise, the output shape will be the input shape excluding
              the last dimension, with the last dimension of the output shape
              set to max_sequence_length.
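
    As a worked sketch of case 2b: a rank 2 `tf.SparseTensor` with a per
    replica dense shape of `[16, 64]` and `max_sequence_length=5` (with no
    output shape set in the feature config) gives an output shape of
    `[16, 5]`, so `dequeue` returns activations of shape `[16, 5, dim]`; with
    `max_sequence_length=0` the output shape would instead be `[16]` and the
    activations `[16, dim]`.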

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

    distributed_dataset = (
        strategy.distribute_datasets_from_function(
            dataset_fn=...,
            options=tf.distribute.InputOptions(
                experimental_fetch_to_device=False)))
    dataset_iterator = iter(distributed_dataset)

    @tf.function
    def training_step():
      def tpu_step(tpu_features):
        with tf.GradientTape() as tape:
          activations = embedding.dequeue()
          tape.watch(activations)

          loss = ... #  some computation involving activations

        embedding_gradients = tape.gradient(loss, activations)
        embedding.apply_gradients(embedding_gradients)

      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features,))

    training_step()
    ```
1164
1165    NOTE: You should specify `training=True` when using
1166    `embedding.apply_gradients` as above and `training=False` when not using
1167    `embedding.apply_gradients` (e.g. for frozen embeddings or when doing
1168    evaluation).
1169
1170    For finer grained control, in the above example the line
1171
1172    ```
1173      embedding.enqueue(embedding_features, training=True)
1174    ```
1175
1176    may be replaced with
1177
1178    ```
1179      per_core_embedding_features = self.strategy.experimental_local_results(
1180          embedding_features)
1181
1182      def per_core_enqueue(ctx):
1183        core_id = ctx.replica_id_in_sync_group
1184        device = strategy.extended.worker_devices[core_id]
1185        embedding.enqueue(per_core_embedding_features[core_id],
1186                          device=device)
1187
1188      strategy.experimental_distribute_values_from_function(
1189          per_core_enqueue)
1190    ```
1191
1192    Args:
1193      features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or
1194        `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs
1195        will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor`
1196        or `tf.RaggedTensor` is supported per call.
1197      weights: If not `None`, a nested structure of `tf.Tensor`s,
1198        `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except
1199        that the tensors should be of float type (and they will be downcast to
1200        `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the
1201        same for the parallel entries from `features` and similarly for
1202        `tf.RaggedTensor`s we assume the row_splits are the same.
1203      training: Defaults to `True`. If `False`, enqueue the batch as an
1204        inference batch (forward pass only). Do not call `apply_gradients`
1205        when this is `False`, as this may lead to a deadlock.
1206      name: A name for the underlying op.
1207      device: The device name (e.g. '/task:0/device:TPU:2') where this batch
1208        should be enqueued. This should be set if and only if features is not
1209        a `tf.distribute.DistributedValues` and enqueue is not being called
1210        inside a TPU context (e.g. inside `TPUStrategy.run`).
1211
1212    Raises:
1213      ValueError: When called inside a strategy.run call and input is not
1214        directly taken from the args of the `strategy.run` call; if the size
1215        of any sequence in `features` does not match the corresponding
1216        sequence in `feature_config` (similarly for `weights`, if not
1217        `None`); or if the input shapes of `features` are inconsistent with
1218        one another or differ from those of a previous call.
1219      RuntimeError: When called inside a strategy.run call and inside XLA
1220        control flow, or if the batch size cannot be determined and build
1221        was not called.
1222      TypeError: If the type of any sequence in `features` does not match
1223        corresponding sequence in `feature_config`. Similarly for `weights`, if
1224        not `None`.
1225    """
1226    if not self._using_tpu:
1227      raise RuntimeError("enqueue is not valid when TPUEmbedding object is not "
1228                         "created under a TPUStrategy.")
1229
1230    in_tpu_context = self._raise_error_for_incorrect_control_flow_context()
1231
1232    nest.assert_same_structure(self._feature_config, features)
1233
1234    if not self._verify_output_shapes_on_enqueue:
1235      if not self._output_shapes or not self._built:
1236        raise ValueError(
1237            "Configured not to check output shapes on each enqueue() call; please "
1238            "ensure build() was called with output shapes to initialize "
1239            "the TPU for embeddings.")
1240    else:
1241      input_shapes = self._get_input_shapes(features, in_tpu_context)
1242
1243      self._maybe_build(input_shapes)
1244      # If it is already built, we still need to check that the output shapes
1245      # match the previous ones.
1246      self._check_output_shapes(
1247          self._get_output_shapes_from_input_shapes(input_shapes))
1248
1249    flat_inputs = nest.flatten(features)
1250    flat_weights = [None] * len(flat_inputs)
1251    if weights is not None:
1252      nest.assert_same_structure(self._feature_config, weights)
1253      flat_weights = nest.flatten(weights)
1254    flat_features = nest.flatten_with_joined_string_paths(self._feature_config)
1255    flat_paths, _ = zip(*flat_features)
1256
1257    self._raise_error_for_inputs_not_on_cpu(flat_inputs, flat_paths)
1258    # If we are in a tpu_context, automatically apply outside compilation.
1259    if in_tpu_context:
1260      self._raise_error_for_non_direct_inputs(features)
1261
1262      def generate_enqueue_ops():
1263        """Generate enqueue ops for outside compilation."""
1264        # Note that we use array_ops.where_v2 rather than a python if so that
1265        # the op is explicitly created and the constant ops are both in the
1266        # graph even though we don't expect training to be a tensor (and thus
1267        # generate control flow automatically). This makes it easier to
1268        # rewrite the graph later if we need to fix which mode is used.
1269        mode_override = array_ops.where_v2(training,
1270                                           constant_op.constant("train"),
1271                                           constant_op.constant("inference"))
1272        # Device ordinal is -1 here, a later rewrite will fix this once the op
1273        # is expanded by outside compilation.
1274        enqueue_op = self._generate_enqueue_op(
1275            flat_inputs, flat_weights, flat_features, device_ordinal=-1,
1276            mode_override=mode_override)
1277
1278        # Apply the name tag to the op.
1279        if name is not None:
1280          _add_key_attr(enqueue_op, name)
1281
1282        # Ensure that this op has outbound control flow, otherwise it won't be
1283        # executed.
1284        ops.get_default_graph().control_outputs.append(enqueue_op)
1285
1286      tpu.outside_compilation(generate_enqueue_ops)
1287
1288    elif device is None:
1289      mode_override = "train" if training else "inference"
1290      # We generate enqueue ops per device, so we need to gather all the
1291      # features for a single device into a dict.
1292      # We rely here on the fact that the devices in the PerReplica value occur
1293      # in the same (standard) order as self._strategy.extended.worker_devices.
1294      enqueue_ops = []
1295      for replica_id in range(self._strategy.num_replicas_in_sync):
1296        replica_inputs = distribute_utils.select_replica(replica_id,
1297                                                         flat_inputs)
1298        replica_weights = distribute_utils.select_replica(replica_id,
1299                                                          flat_weights)
1300        tpu_device = self._strategy.extended.worker_devices[replica_id]
1301        # TPU device strings look like /job:worker/replica:0/task:0/device:TPU:0;
1302        # the device ordinal is the last number.
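        # For example (illustrative), '/job:worker/replica:0/task:0/device:TPU:3'
        # parses to a device_index of 3.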
1303        device_ordinal = (
1304            tf_device.DeviceSpec.from_string(tpu_device).device_index)
1305
1306        with ops.device(device_util.get_host_for_device(tpu_device)):
1307          enqueue_op = self._generate_enqueue_op(
1308              replica_inputs, replica_weights, flat_features,
1309              device_ordinal=device_ordinal, mode_override=mode_override)
1310
1311          # Apply the name tag to the op.
1312          if name is not None:
1313            _add_key_attr(enqueue_op, name)
1314          enqueue_ops.append(enqueue_op)
1315      ops.get_default_graph().control_outputs.extend(enqueue_ops)
1316    else:
1317      mode_override = "train" if training else "inference"
1318      device_spec = tf_device.DeviceSpec.from_string(device)
1319      if device_spec.device_type != "TPU":
1320        raise ValueError(
1321            "Non-TPU device {} passed to enqueue.".format(device))
1322
1323      with ops.device(device_util.get_host_for_device(device)):
1324        enqueue_op = self._generate_enqueue_op(
1325            flat_inputs, flat_weights, flat_features,
1326            device_ordinal=device_spec.device_index,
1327            mode_override=mode_override)
1328
1329        # Apply the name tag to the op.
1330        if name is not None:
1331          _add_key_attr(enqueue_op, name)
1332        ops.get_default_graph().control_outputs.append(enqueue_op)
1333
1334  def _get_input_shapes(self, tensors,
1335                        in_tpu_context: bool) -> List[TensorShape]:
1336    """Get the input shapes from the input tensor."""
1337    input_shapes = []
1338    for (path, maybe_tensor), feature in zip(
1339        nest.flatten_with_joined_string_paths(tensors),
1340        nest.flatten(self._feature_config)):
1341      if not in_tpu_context:
1342        tensor = distribute_utils.select_replica(0, maybe_tensor)
1343      else:
1344        tensor = maybe_tensor
1345
1346      if isinstance(tensor, ops.Tensor):
1347        input_shapes.append(
1348            self._get_input_shape_for_tensor(tensor, feature, path))
1349      elif isinstance(tensor, sparse_tensor.SparseTensor):
1350        input_shapes.append(
1351            self._get_input_shape_for_sparse_tensor(tensor, feature, path))
1352      elif isinstance(tensor, ragged_tensor.RaggedTensor):
1353        input_shapes.append(
1354            self._get_input_shape_for_ragged_tensor(tensor, feature, path))
1355    return input_shapes
1356
1357  def _get_input_shape_for_tensor(self, tensor, feature, path) -> TensorShape:
1358    """Get the input shape for the dense tensor."""
1359    shape = tensor.shape.as_list()
1360    if len(shape) < 1:
1361      raise ValueError("Only rank 1 and above dense tensor is supported,"
1362                       " find rank {} sparse tensor for input {}".format(
1363                           len(shape), path))
1364    if len(shape) > 1 and shape[-1] != 1:
1365      raise ValueError(
1366          "Rank 2 or above dense tensor should have last dimension as 1 "
1367          "as the last dimension will always be reduced. "
1368          "Instead got dense tensor as shape {}".format(shape))
1369    return TensorShape(shape)
1370
1371  def _get_input_shape_for_sparse_tensor(self, tensor, feature,
1372                                         path) -> TensorShape:
1373    """Get the input shape for the sparse tensor."""
1374    shape = tensor.shape.as_list()
1375    # Only sparse tensors of rank 2 and above are supported.
1376    if len(shape) < 2:
1377      raise ValueError("Only sparse tensors of rank 2 and above are supported;"
1378                       " found rank {} sparse tensor for input {}".format(
1379                           len(shape), path))
1380    if not feature.output_shape and feature.max_sequence_length > 0:
1381      # If max_sequence_length is set and the output shape for FeatureConfig
1382      # is not set, we modify the shape of the input feature. Only rank 2
1383      # feature output shapes are modified.
1384      if len(shape) == 2:
1385        # If the sparse tensor is 2D and max_sequence_length is set,
1386        # we need to add one dimension to the input feature.
1387        shape.insert(len(shape) - 1, feature.max_sequence_length)
1388
1389    return TensorShape(shape)
1390
1391  def _get_input_shape_for_ragged_tensor(self, tensor, feature,
1392                                         path) -> TensorShape:
1393    """Get the input shape for the ragged tensor."""
1394    shape = tensor.shape.as_list()
1395    # Only rank 2 ragged tensors are supported.
1396    if len(shape) != 2:
1397      raise ValueError("Only rank 2 ragged tensors are supported;"
1398                       " found rank {} ragged tensor for input {}".format(
1399                           len(shape), path))
1400    if not feature.output_shape and feature.max_sequence_length > 0:
1401      # If max_sequence_length is set and the output shape for FeatureConfig
1402      # is not set, add the sequence length as the second-to-last dimension
1403      # of the ragged tensor.
1404      shape.insert(len(shape) - 1, feature.max_sequence_length)
1405
1406    return TensorShape(shape)
1407
1408  def _update_output_shapes(self, incoming_output_shapes: List[TensorShape]):
1409    """Update the existing output shapes based on the new output shapes.
1410
1411    The existing output shapes always have higher priority than the new
1412    incoming output shapes.
1413    Args:
1414      incoming_output_shapes: nested structure of TensorShape to override the
1415        existing output shapes.
1416    """
1417    nest.assert_same_structure(self._output_shapes, incoming_output_shapes)
1418    updated_output_shapes = []
1419    for old_output_shape, incoming_output_shape in zip(self._output_shapes,
1420                                                       incoming_output_shapes):
1421      if old_output_shape:
1422        updated_output_shapes.append(old_output_shape)
1423      else:
1424        updated_output_shapes.append(incoming_output_shape)
1425    self._output_shapes = updated_output_shapes
1426
1427  def _check_output_shapes(self, incoming_output_shapes: List[TensorShape]):
1428    """Check the incoming output shapes against the output shapes stored."""
1429    # The incoming output shapes should have the same structure as the
1430    # existing output shapes.
1431    nest.assert_same_structure(self._output_shapes, incoming_output_shapes)
1432
1433    for (path, _), old_output_shape, incoming_output_shape in zip(
1434        nest.flatten_with_joined_string_paths(self._feature_config),
1435        self._output_shapes, incoming_output_shapes):
1436      # First check if both shapes are not None.
1437      if old_output_shape and incoming_output_shape:
1438      # We skip the check when the incoming output shape is rank 1 or 2 and
1439      # the rank of the old output shape is larger. This can happen for
1440      # (sequence) ragged tensors; we push the check down to the enqueue op.
1441        if (len(incoming_output_shape) in (1, 2) and
1442            len(old_output_shape) > len(incoming_output_shape)):
1443          continue
1444        if len(old_output_shape) != len(
1445            incoming_output_shape) or not self._is_tensor_shape_match(
1446                old_output_shape, incoming_output_shape):
1447          raise ValueError(
1448              f"Inconsistent shape founded for input feature {path}, "
1449              f"Output shape is set to be {old_output_shape}, "
1450              f"But got incoming output shape {incoming_output_shape}")
1451
1452  def _check_output_shapes_fully_defined(self):
1453    """Check if the output shape is fully defined."""
1454    for (path, _), output_shape in zip(
1455        nest.flatten_with_joined_string_paths(self._feature_config),
1456        self._output_shapes):
1457      if not output_shape.is_fully_defined():
1458        raise ValueError(
1459            f"Input Feature {path} has output shape set as "
1460            f"{output_shape} which is not fully defined. "
1461            "Please specify the fully defined shape in either FeatureConfig "
1462            "or for the build method.")
1463
1464  def _is_tensor_shape_match(self, shape_a: TensorShape,
1465                             shape_b: TensorShape) -> bool:
1466    """Check if shape b matches with shape a."""
1467    for s_a, s_b in zip(shape_a.as_list(), shape_b.as_list()):
1468      if s_a and s_b and s_a != s_b:
1469        return False
1470    return True
1471
1472  def _get_output_shapes_from_batch_size(self, per_replica_batch_size):
1473    """Get the output shapes from the batch size."""
1474    output_shapes = []
1475    for feature in nest.flatten(self._feature_config):
1476      if not feature.output_shape and feature.max_sequence_length > 0:
1477        output_shapes.append(
1478            TensorShape([per_replica_batch_size, feature.max_sequence_length]))
1479      else:
1480        output_shapes.append(TensorShape(per_replica_batch_size))
1481    return output_shapes
1482
1483
1484@def_function.function
1485def _load_variables_impl(
1486    config: Text,
1487    hosts: List[Tuple[int, Text]],
1488    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
1489    table_config: tpu_embedding_v2_utils.TableConfig):
1490  """Load embedding tables to onto TPU for each table and host.
1491
1492  Args:
1493    config: A serialized TPUEmbeddingConfiguration proto.
1494    hosts: A list of CPU devices, one per host.
1495    variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key
1496      is the table name, second key is 'parameters' or the optimizer slot name.
1497    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
1498  """
1499  def select_fn(host_id):
1500
1501    def select_or_zeros(x):
1502      if host_id >= len(x.variables):
1503        # In the edge case where we have more hosts than variables, due to using
1504        # a small number of rows, we load zeros for the later hosts. We copy
1505        # the shape of the first host's variables, which we assume is defined
1506        # because TableConfig guarantees at least one row.
1507        return array_ops.zeros_like(x.variables[0])
1508      return x.variables[host_id]
1509
1510    return select_or_zeros
1511
1512  for host_id, host in enumerate(hosts):
1513    with ops.device(host):
1514      host_variables = nest.map_structure(select_fn(host_id), variables)
1515      for table in table_config:
1516        table.optimizer._load()(  # pylint: disable=protected-access
1517            table_name=table.name,
1518            num_shards=len(hosts),
1519            shard_id=host_id,
1520            config=config,
1521            **host_variables[table.name])
1522        # Ensure that only the first table/first host gets a config so that we
1523        # don't bloat the graph by attaching this large string to each op.
1524        # We have num_tables * num_hosts of these ops, so for models with many
1525        # tables training on a large slice, this can be an issue.
1526        config = None
1527
1528
1529@def_function.function
1530def _retrieve_variables_impl(
1531    config: Text,
1532    hosts: List[Tuple[int, Text]],
1533    variables: Dict[Text, Dict[Text, tf_variables.Variable]],
1534    table_config: tpu_embedding_v2_utils.TableConfig):
1535  """Retrieve embedding tables from TPU to host memory.
1536
1537  Args:
1538    config: A serialized TPUEmbeddingConfiguration proto.
1539    hosts: A list of all the host CPU devices.
1540    variables: A dictionary of dictionaries of TPUEmbeddingVariables. First key
1541      is the table name, second key is 'parameters' or the optimizer slot name.
1542    table_config: A list of tf.tpu.experimental.embedding.TableConfig objects.
1543  """
1544  for host_id, host in enumerate(hosts):
1545    with ops.device(host):
1546      for table in table_config:
1547        retrieved = table.optimizer._retrieve()(  # pylint: disable=protected-access
1548            table_name=table.name,
1549            num_shards=len(hosts),
1550            shard_id=host_id,
1551            config=config)
1552        # When there are no slot variables (e.g. with SGD) this returns a
1553        # single tensor rather than a tuple. In this case we put the tensor in
1554        # a list to make the following code easier to write.
1555        if not isinstance(retrieved, tuple):
1556          retrieved = (retrieved,)
1557
1558        for i, slot in enumerate(["parameters"] +
1559                                 table.optimizer._slot_names()):  # pylint: disable=protected-access
1560          # We must assign the CPU variables the values of tensors that were
1561          # returned from the TPU.
1562          sharded_var = variables[table.name][slot]
1563          if host_id < len(sharded_var.variables):
1564            # In the edge case where we have more hosts than variables, due to
1565            # using a small number of rows, we skip the later hosts.
1566            sharded_var.variables[host_id].assign(retrieved[i])
1567        # Ensure that only the first table/first host gets a config so that we
1568        # don't bloat the graph by attaching this large string to each op.
1569        # We have num_tables * num_hosts of these ops, so for models with many
1570        # tables training on a large slice, this can be an issue.
1571        config = None
1572
1573
1574def _save_callback(trackables, **unused_kwargs):
1575  for trackable in trackables.values():
1576    trackable._retrieve_variables()  # pylint: disable=protected-access
1577  return []
1578
1579
1580def _restore_callback(trackables, **unused_kwargs):
1581  for trackable in trackables.values():
1582    trackable._load_variables()  # pylint: disable=protected-access
1583
1584
1585registration.register_tf_checkpoint_saver(
1586    "TPUEmbeddingCallback",
1587    predicate=lambda x: isinstance(x, TPUEmbedding),
1588    save_fn=_save_callback,
1589    restore_fn=_restore_callback,
1590    # Set strict_predicate_restore to `False` because the isinstance
1591    # predicate check does not pass after a TPUEmbedding object is loaded
1592    # from SavedModel.
1593    strict_predicate_restore=False
1594)
1595
1596
1597def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]:
1598  """Returns a sorted list of CPU devices for the remote jobs.
1599
1600  Args:
1601    strategy: A TPUStrategy object.
1602
1603  Returns:
1604    A sorted list of device strings.
1605  """
1606  list_of_hosts = []
1607  # Assume worker_devices is sorted by task.
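  # For example (illustrative device strings): the TPU devices
  # '/job:worker/replica:0/task:0/device:TPU:0' and
  # '/job:worker/replica:0/task:0/device:TPU:1' both map to the host
  # '/job:worker/replica:0/task:0/device:CPU:0', which is appended only once.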
1608  for tpu_device in strategy.extended.worker_devices:
1609    host = device_util.get_host_for_device(tpu_device)
1610    if host not in list_of_hosts:
1611      list_of_hosts.append(host)
1612  assert len(list_of_hosts) == strategy.extended.num_hosts
1613  return list_of_hosts
1614
1615
1616def extract_variable_info(
1617    kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]:
1618  """Extracts the variable creation attributes from the kwargs.
1619
1620  Args:
1621    kwargs: a dict of keyword arguments that were passed to a variable creator
1622      scope.
1623
1624  Returns:
1625    A tuple of variable name, shape, dtype, initialization function.
1626  """
1627  if (isinstance(kwargs["initial_value"], functools.partial) and (
1628      "shape" in kwargs["initial_value"].keywords or
1629      kwargs["initial_value"].args)):
1630    # Sometimes shape is passed positionally, sometimes it's passed as a kwarg.
1631    if "shape" in kwargs["initial_value"].keywords:
1632      shape = kwargs["initial_value"].keywords["shape"]
1633    else:
1634      shape = kwargs["initial_value"].args[0]
1635    return (kwargs["name"], shape,
1636            kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
1637            kwargs["initial_value"].func)
1638  elif "shape" not in kwargs or kwargs["shape"] is None or not callable(
1639      kwargs["initial_value"]):
1640    raise ValueError(
1641        "Unable to extract initializer function and shape from {}. Please "
1642        "either pass a function that expects a shape and dtype as the "
1643        "initial value for your variable or functools.partial object with "
1644        "the shape and dtype kwargs set. This is needed so that we can "
1645        "initialize the shards of the ShardedVariable locally.".format(
1646            kwargs["initial_value"]))
1647  else:
1648    return (kwargs["name"], kwargs["shape"], kwargs["dtype"],
1649            kwargs["initial_value"])
1650
1651
1652def make_sharded_variable_creator(
1653    hosts: List[Text]) -> Callable[..., TPUEmbeddingVariable]:
1654  """Makes a sharded variable creator given a list of hosts.
1655
1656  Args:
1657    hosts: a list of tensorflow devices on which to shard the tensors.
1658
1659  Returns:
1660    A variable creator function.
1661  """
1662
1663  def sharded_variable_creator(
1664      next_creator: Callable[..., tf_variables.Variable], *args, **kwargs):
1665    """The sharded variable creator."""
1666    kwargs["skip_mirrored_creator"] = True
1667
1668    num_hosts = len(hosts)
1669    name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs)
1670    initial_value = kwargs["initial_value"]
1671    rows = shape[0]
1672    cols = shape[1]
1673    partial_partition = rows % num_hosts
1674    full_rows_per_host = rows // num_hosts
1675    # We partition as if we were using MOD sharding: each of the `num_hosts`
1676    # hosts gets at least `full_rows_per_host` rows, and the first
1677    # `partial_partition` hosts get one additional row when the number of
1678    # rows is not cleanly divisible. Note that `full_rows_per_host` may be zero.
1679    partitions = (
1680        [full_rows_per_host + 1] * partial_partition
1681        + [full_rows_per_host] * (num_hosts - partial_partition))
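    # For example (illustrative): rows=10 and num_hosts=4 give
    # full_rows_per_host=2 and partial_partition=2, so partitions is
    # [3, 3, 2, 2].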
1682    variables = []
1683    sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args
1684
1685    # Keep track of offset for sharding aware initializers.
1686    offset = 0
1687    kwargs["dtype"] = dtype
1688    for i, p in enumerate(partitions):
1689      if p == 0:
1690        # Skip variable creation for empty partitions, resulting from the edge
1691        # case of 'rows < num_hosts'. This is safe because both load/restore
1692        # can handle the missing values.
1693        continue
1694      with ops.device(hosts[i]):
1695        kwargs["name"] = "{}_{}".format(name, i)
1696        kwargs["shape"] = (p, cols)
1697        if sharding_aware:
1698          shard_info = base.ShardInfo(kwargs["shape"], (offset, 0))
1699          kwargs["initial_value"] = functools.partial(
1700              initial_value, shard_info=shard_info)
1701          offset += p
1702        else:
1703          kwargs["initial_value"] = functools.partial(
1704              unwrapped_initial_value, kwargs["shape"], dtype=dtype)
1705        variables.append(next_creator(*args, **kwargs))
1706    return TPUEmbeddingVariable(variables, name=name)
1707  return sharded_variable_creator
1708