# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for TPU Embeddings mid level API on TPU."""
import functools
from absl.testing import parameterized
import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu.tests import tpu_embedding_base_test


class TPUEmbeddingTest(tpu_embedding_base_test.TPUEmbeddingBaseTest):

  def test_unsupported_optimizer(self):
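    # tpu_embedding.AdagradParameters belongs to the legacy TF1 mid level
    # API; TPUEmbedding here only accepts optimizers from
    # tpu_embedding_v2_utils, so the v1 class should be rejected.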
    with self.assertRaisesRegex(
        ValueError, 'is an unsupported optimizer class.'):
      with self._get_strategy().scope():
        tpu_embedding_v2.TPUEmbedding(
            self.feature_config,
            tpu_embedding.AdagradParameters(learning_rate=0.1))

  def test_variable_learning_rate(self):
    num_steps = 10
    num_steps_float = float(num_steps)
    starting_lr = 1.0
    ending_lr = 0.5

    strategy = self._get_strategy()
    num_replicas = strategy.num_replicas_in_sync

    # Create the model under the strategy scope.
    with strategy.scope():
      step_counter = tf_variables.Variable(0.0, dtype=dtypes.float32)

      def lr_function():
        return gen_math_ops.maximum(
            ending_lr,
            starting_lr + ((ending_lr - starting_lr) * step_counter) /
            num_steps_float)

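      # Passing a zero-argument callable instead of a float makes the
      # learning rate dynamic: it is re-evaluated each time gradients are
      # applied, so assigning to step_counter changes the rate used on
      # subsequent steps.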
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=lr_function)
      table_config = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_replicas,
          dim=4,
          initializer=init_ops_v2.Constant(np.zeros((num_replicas, 4))),
          combiner='sum', name='table')
      mid_level_api = tpu_embedding_v2.TPUEmbedding(
          feature_config={
              'feature': tpu_embedding_v2_utils.FeatureConfig(
                  table=table_config, name='feature')},
          optimizer=optimizer)

    feature = {
        'feature': constant_op.constant([0], shape=(1, 1), dtype=dtypes.int32)
    }

    def input_fn(ctx):
      del ctx
      return dataset_ops.DatasetV2.from_tensors(feature).repeat()

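    # Keep the dataset on the host: embedding lookup inputs are enqueued to
    # the TPU embedding hardware from the CPU, so they must not be prefetched
    # to the device.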
    dist = strategy.distribute_datasets_from_function(
        input_fn,
        options=distribute_lib.InputOptions(experimental_fetch_to_device=False))
    dist_iter = iter(dist)

    @def_function.function
    def test_fn():
      def step():
        with backprop.GradientTape() as tape:
          activations = mid_level_api.dequeue()
          tape.watch(activations)
          result = math_ops.reduce_sum(activations['feature'])
          loss = result / num_replicas
        grads = tape.gradient(loss, activations)
        mid_level_api.apply_gradients(grads)
        return activations['feature']

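      # enqueue is called once per step outside the replicated computation;
      # the matching dequeue runs inside step(), which strategy.run
      # replicates.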
      mid_level_api.enqueue(next(dist_iter), training=True)
      return strategy.run(step)

    # Run model.
    results = []
    for _ in range(num_steps):
      result = test_fn()
      results.append(self._unpack(strategy, result))
      step_counter.assign_add(1.0)

    # The table is 4 elements wide, with a per-replica batch size of 1 and
    # every lookup using id 0. The loss is the sum of the looked-up row's
    # entries divided by the number of replicas, so the per-replica gradient
    # is 1/num_replicas for each entry of row 0, with no other updates. The
    # reduced gradient is therefore 1 per entry.
    # Learning rate schedule over num_steps steps:
    # 1.0 0.95 0.9 0.85 0.8 ...
    # Since we use SGD and the gradient is one, each entry of the first row
    # of the table follows the negative partial sums of that schedule:
    # 0.0 -1.0 -1.95 -2.85 ...

    learning_rates = [starting_lr - (starting_lr - ending_lr) / num_steps * j
                      for j in range(num_steps)]
    cumsum = [sum(learning_rates[0:j]) for j in range(num_steps)]
    goldens = [[[-cumsum[i]] * table_config.dim] * num_replicas
               for i in range(num_steps)]
    self.assertAllClose(results, goldens)

  @parameterized.parameters([True, False])
  def test_optimizer_with_slot_creation_fn(self, use_tpu):
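    # slot_variable_creation_fn gives the caller control over how optimizer
    # slot variables are created: it receives the table, the slot names and
    # their default initializers (unused here), and must return a dict
    # mapping each slot name to a variable.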
    def slot_creation_fn(table, slot_names, _):
      slots = {}
      for slot in slot_names:
        slots[slot] = tf_variables.Variable(
            name='{}_{}'.format(table.name, slot),
            initial_value=functools.partial(
                init_ops_v2.Zeros(), shape=table.shape, dtype=dtypes.float32),
            trainable=False)
      return slots
    optimizer = tpu_embedding_v2_utils.Adagrad(
        learning_rate=0.1,
        slot_variable_creation_fn=slot_creation_fn)
    if use_tpu:
      strategy = self._get_strategy()
    else:
      strategy = distribution_strategy_context.get_strategy()
    with strategy.scope():
      mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config=self.feature_config,
          optimizer=optimizer)
      # We aren't going to actually run anything, so the batch_size here does
      # not matter.
      mid_level.build(self.batch_size)
    video_accumulator = mid_level._variables['video']['accumulators']
    user_accumulator = mid_level._variables['user']['accumulators']
    if use_tpu:
      # To check the table contents (ensure that they are zero rather than
      # the normal initial accumulator value specified in the optimizer
      # config), we need to select the underlying table variable on TPU.
      # We only have one shard on Forge.
      video_accumulator = video_accumulator.variables[0]
      user_accumulator = user_accumulator.variables[0]

    self.assertAllClose(video_accumulator.numpy(),
                        np.zeros((self.table_video.vocabulary_size,
                                  self.table_video.dim)))
    self.assertAllClose(user_accumulator.numpy(),
                        np.zeros((self.table_user.vocabulary_size,
                                  self.table_user.dim)))

  def test_optimizer_with_slot_creation_fn_non_partial(self):
    def slot_creation_fn(table, slot_names, _):
      slots = {}
      for slot in slot_names:
        # Note that we don't pass functools.partial here, so on TPU we can't
        # extract the shape. We expect the error below.
        slots[slot] = tf_variables.Variable(
            name='{}_{}'.format(table.name, slot),
            initial_value=init_ops_v2.Zeros()(shape=table.shape,
                                              dtype=dtypes.float32),
            trainable=False)
      return slots
    optimizer = tpu_embedding_v2_utils.Adagrad(
        learning_rate=0.1,
        slot_variable_creation_fn=slot_creation_fn)
    strategy = self._get_strategy()
    with strategy.scope():
      mid_level_api = tpu_embedding_v2.TPUEmbedding(
          feature_config=self.feature_config,
          optimizer=optimizer)
      with self.assertRaisesRegex(ValueError,
                                  'Unable to extract initializer function'):
        # We aren't going to actually run anything, so the batch_size here does
        # not matter.
        mid_level_api.build(self.batch_size)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()