# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for optimizer."""

from tensorflow.python.distribute import cross_device_ops
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent


class OptimizerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testBasic(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      # Note that for eager execution, minimize expects a function instead of a
      # Tensor.
      global_step = resource_variable_ops.ResourceVariable(
          array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
      self.evaluate(opt_op)
      # Validate updated params
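      # loss = 5 * var0 + 3 * var1, so the per-element gradients are 5 and 3;
      # with a learning rate of 3.0 each element of var0 drops by 15 and each
      # element of var1 drops by 9.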
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testAggregationMethod(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost,
            global_step, [var0, var1],
            aggregation_method=gradients_util.AggregationMethod.
            EXPERIMENTAL_ACCUMULATE_N)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testPrecomputedGradient(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost, global_step, [var0, var1], grad_loss=grad_loss)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params
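        # grad_loss scales the computed gradients, so var0's effective gradient
        # is 5 * [42, -42] and var1's is 3 * [42, -42] before the 3.0 learning
        # rate is applied.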
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testNoVariables(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      # pylint: disable=cell-var-from-loop
      def loss():
        var0 = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype, trainable=False, name='a')
        var1 = resource_variable_ops.ResourceVariable(
            [3.0, 4.0], dtype=dtype, trainable=False, name='b')
        return 5 * var0 + var1
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError, 'No.*variables'):
        sgd_op.minimize(loss)

  @test_util.run_in_graph_and_eager_modes
  def testNoGradients(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      # pylint: disable=cell-var-from-loop
      def loss():
        return 5 * var0
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError, 'No gradients'):
        # var1 has no gradient
        sgd_op.minimize(loss, var_list=[var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_Minimize(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      def loss():
        return constant_op.constant(5.0)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError,
                                  'No gradients provided for any variable'):
        sgd_op.minimize(loss, var_list=[var0, var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError,
                                  'No gradients provided for any variable'):
        sgd_op.apply_gradients([(None, var0), (None, var1)])

  @test_util.run_in_graph_and_eager_modes
  def testGradientsAsVariables(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
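      # compute_gradients returns a list of (gradient, variable) pairs.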
      # Convert gradients to tf.Variables
      converted_grads = [
          resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
                                                 name='c_%d_%d' % (i, j))
          for j, gv in enumerate(grads_and_vars)
      ]
      convert_ops = [
          state_ops.assign(converted_grads[j], gv[0])
          for j, gv in enumerate(grads_and_vars)
      ]

      self.evaluate(variables.global_variables_initializer())
      # Run convert_ops to copy the gradients into converted_grads
      self.evaluate(convert_ops)
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      # Run 1 step of sgd through optimizer
      converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
      opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
      self.evaluate(opt_op)

      # Validate updated params
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testComputeGradientsWithTensors(self):
    x = ops.convert_to_tensor(1.0)
    def f():
      return x * x
    sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
    grads_and_vars = sgd_op.compute_gradients(f, [x])
    self.assertEqual(1, len(grads_and_vars))
    grad, x_as_var = grads_and_vars[0]
    self.assertIs(x, x_as_var)
    self.assertEqual(2.0, self.evaluate(grad))

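    # Plain tensors are not variables, so apply_gradients cannot update them.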
    with self.assertRaises(NotImplementedError):
      sgd_op.apply_gradients(grads_and_vars)

  @test_util.run_deprecated_v1
  def testTrainOp(self):
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0])
      var1 = variables.Variable([3.0, 4.0])
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
      self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))

  @test_util.run_deprecated_v1
  def testConstraint(self):
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op.run()
      # Validate updated params
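      # Without constraints the step would yield [-14., -13.] and [-6., -5.];
      # the constraints clip the results into [-0.1, 0.] and [0., 1.].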
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testGetSlotUnderDistributedStrategy(self):
    # Only run this test in graph mode so we don't need an actual GPU.
    ds = mirrored_strategy.MirroredStrategy(
        ['CPU:0', 'GPU:0'],
        cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
    # We need an optimizer that creates slots.
    optimizer = adam.AdamOptimizer()

    def f():
      v = variables.Variable([1.0])
      self.assertTrue(distribute_utils.is_distributed_variable(v))
      # Slot variables are created in the first call to apply_gradients.
      optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
      self.assertTrue(optimizer.get_slot_names())
      for name in optimizer.get_slot_names():
        slot = optimizer.get_slot(v, name)
        self.assertIsNotNone(slot)
        self.assertTrue(distribute_utils.is_distributed_variable(slot))

    ds.run(f)


if __name__ == '__main__':
  test.main()