# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for optimizer."""

from tensorflow.python.distribute import cross_device_ops
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent


class OptimizerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testBasic(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      # Note that for eager execution, minimize expects a function instead of
      # a Tensor.
      global_step = resource_variable_ops.ResourceVariable(
          array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
      self.evaluate(opt_op)
      # Validate updated params
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testAggregationMethod(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost,
            global_step, [var0, var1],
            aggregation_method=gradients_util.AggregationMethod.
            EXPERIMENTAL_ACCUMULATE_N)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testPrecomputedGradient(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 + 3 * var1
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name='global_step')
        sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
        opt_op = sgd_op.minimize(
            cost, global_step, [var0, var1], grad_loss=grad_loss)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op.run()
        # Validate updated params
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testNoVariables(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      # pylint: disable=cell-var-from-loop
      def loss():
        var0 = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype, trainable=False, name='a')
        var1 = resource_variable_ops.ResourceVariable(
            [3.0, 4.0], dtype=dtype, trainable=False, name='b')
        return 5 * var0 + var1
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError, 'No.*variables'):
        sgd_op.minimize(loss)

  @test_util.run_in_graph_and_eager_modes
  def testNoGradients(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      # pylint: disable=cell-var-from-loop
      def loss():
        return 5 * var0
      # pylint: enable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError, 'No gradients'):
        # var1 has no gradient
        sgd_op.minimize(loss, var_list=[var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_Minimize(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      def loss():
        return constant_op.constant(5.0)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError,
                                  'No gradients provided for any variable'):
        sgd_op.minimize(loss, var_list=[var0, var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a_%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b_%d' % i)
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      with self.assertRaisesRegex(ValueError,
                                  'No gradients provided for any variable'):
        sgd_op.apply_gradients([(None, var0), (None, var1)])

  @test_util.run_in_graph_and_eager_modes
  def testGradientsAsVariables(self):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      # Note that we name the variables uniquely here since the variables don't
      # seem to be getting deleted at the end of the loop.
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype,
                                                    name='a%d' % i)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype,
                                                    name='b%d' % i)
      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
      # Convert gradients to tf.Variables
      converted_grads = [
          resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
                                                 name='c_%d_%d' % (i, j))
          for j, gv in enumerate(grads_and_vars)
      ]
      convert_ops = [
          state_ops.assign(converted_grads[j], gv[0])
          for j, gv in enumerate(grads_and_vars)
      ]

      self.evaluate(variables.global_variables_initializer())
      # Run convert_ops to copy the computed gradients into the variables.
      self.evaluate(convert_ops)
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      # Run 1 step of sgd through optimizer
      converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
      opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
      self.evaluate(opt_op)

      # Validate updated params
      self.assertAllClose([-14., -13.], self.evaluate(var0))
      self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testComputeGradientsWithTensors(self):
    x = ops.convert_to_tensor(1.0)
    def f():
      return x * x
    sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
    grads_and_vars = sgd_op.compute_gradients(f, [x])
    self.assertEqual(1, len(grads_and_vars))
    grad, x_as_var = grads_and_vars[0]
    self.assertIs(x, x_as_var)
    self.assertEqual(2.0, self.evaluate(grad))

    with self.assertRaises(NotImplementedError):
      sgd_op.apply_gradients(grads_and_vars)

  @test_util.run_deprecated_v1
  def testTrainOp(self):
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0])
      var1 = variables.Variable([3.0, 4.0])
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
      self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))

  @test_util.run_deprecated_v1
  def testConstraint(self):
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      cost = 5 * var0 + 3 * var1
      global_step = variables.Variable(
          array_ops.zeros([], dtypes.int64), name='global_step')
      sgd_op = gradient_descent.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op.run()
      # Validate updated params
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testGetSlotUnderDistributedStrategy(self):
    # Only run this test in graph mode so we don't need actual GPU.
    ds = mirrored_strategy.MirroredStrategy(
        ['CPU:0', 'GPU:0'],
        cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
    # We need an optimizer that creates slots.
    optimizer = adam.AdamOptimizer()

    def f():
      v = variables.Variable([1.0])
      self.assertTrue(distribute_utils.is_distributed_variable(v))
      # Slot variables are created in the first call to apply_gradients.
      optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
      self.assertTrue(optimizer.get_slot_names())
      for name in optimizer.get_slot_names():
        slot = optimizer.get_slot(v, name)
        self.assertIsNotNone(slot)
        self.assertTrue(distribute_utils.is_distributed_variable(slot))

    ds.run(f)


if __name__ == '__main__':
  test.main()