# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template."""
import functools
import traceback

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


def variable_scoped_function(trainable=True):
  return variable_scope.get_variable(
      "dummy", shape=[1], trainable=trainable,
      initializer=init_ops.zeros_initializer())


def internally_variable_scoped_function(scope_name):
  with variable_scope.variable_scope(scope_name):
    return variable_scope.get_variable(
        "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable."""
  variables.Variable(0, trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable."""
  variable_scope.get_variable(name, shape=[1], trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def variable_scoped_function_with_local_variable():
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


class TemplateTest(test.TestCase):

  @test_util.run_deprecated_v1
  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    with session.Session() as sess:
      self.evaluate(variables.global_variables_initializer())
      initial_test_loss = self.evaluate(test_loss)
      self.evaluate(train_op)
      final_test_loss = self.evaluate(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  def test_eager_delayed_store_pickup(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.

    This test also shows that it can pick up explicitly set variable stores
    even if they are only set before the first template usage.
    """
    with context.eager_mode():
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      store = variable_scope._VariableStore()
      store._store_eager_variables = True

      with variable_scope.with_variable_store(store):
        optimizer = gradient_descent.GradientDescentOptimizer(0.1)
        initial_test_loss = test_loss()
        optimizer.minimize(train_loss)
        final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

      # Verify that the explicitly set store is not empty
      # and that make_template picked it up.
      self.assertEqual(set(store._vars.keys()), {"line/w", "line/b"})

      # But the store should only get picked up once, so a second
      # store will go unused:
      second_store = variable_scope._VariableStore()
      second_store._store_eager_variables = True

      with variable_scope.with_variable_store(second_store):
        optimizer = gradient_descent.GradientDescentOptimizer(0.1)
        test_loss()
        optimizer.minimize(train_loss)
        test_loss()
      self.assertEmpty(second_store._vars)

  @test_util.run_in_graph_and_eager_modes
  def test_skip_stack_frames(self):
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_empty_name(self):
    tpl = template.make_template("", variable_scoped_function)
    with variable_scope.variable_scope("outer"):
      x = variable_scope.get_variable("x", [])
      v = tpl()
    self.assertEqual("outer/", tpl.variable_scope_name)
    self.assertEqual("outer//dummy:0", v.name)
    if context.executing_eagerly():
      # In eager mode `x` is not visible to the template since the template does
      # not rely on global collections.
      self.assertEqual(1, len(tpl.variables))
      self.assertIs(v, tpl.variables[0])
    else:
      self.assertEqual([x, v], tpl.variables)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegex(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError,
          "unique_name_ cannot be used when eager execution is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  @test_util.run_deprecated_v1
  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertIs(v1, v2)
    self.assertIs(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes
  def test_template_without_name(self):
    with self.assertRaisesRegex(ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes
  def test_make_template(self):
    # Test that we can call it with positional and keyword arguments.
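    # Extra keyword arguments (here scope_name) are forwarded to the wrapped
    # function each time the template is called.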
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertIs(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertIs(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
        with variable_scope.variable_scope(vs, reuse=True):
          v2 = variable_scope.get_variable("x")
        self.assertIs(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(1, len(nested1.variables))
      self.assertEqual(1, len(nested2.variables))
      self.assertIs(nested1.variables[0], v1)
      self.assertIs(nested2.variables[0], v2)
      self.assertEqual(1, len(nested1.trainable_variables))
      self.assertEqual(1, len(nested2.trainable_variables))
      self.assertIs(nested1.trainable_variables[0], v1)
      self.assertIs(nested2.trainable_variables[0], v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertIs(v1, v3)
    self.assertIs(v2, v4)
    for v, w in zip(tmpl1.variables, [v1, v2]):
      self.assertIs(v, w)
    for v, w in zip(tmpl1.trainable_variables, [v1, v2]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertIsNot(v1, v5)
    self.assertIsNot(v2, v6)
    for v, w in zip(tmpl2.variables, [v5, v6]):
      self.assertIs(v, w)
    for v, w in zip(tmpl2.trainable_variables, [v5, v6]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

    self.assertEqual(["nested", "nested_1"], list(tmpl1._trackable_children()))

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      self.assertEqual(len(v1), 1)
      self.assertEqual(len(v2), 1)

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1[0], v2[0])
      self.assertIs(nested1.trainable_variables[0], v1[0])
      self.assertIs(nested2.trainable_variables[0], v2[0])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    for v, w in zip(v1, v2):
      self.assertIs(v, w)

    # tmpl1 and tmpl2 should not share variables.
    for v, w in zip(v1, v3):
      self.assertIsNot(v, w)

    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes
  def test_immediate_scope_creation(self):
    # Create templates in scope a, then call them in scope b. By default,
    # make_template captures the scope the first time the template is called;
    # with create_scope_now_=True it captures the scope at construction time.
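    # The bare True/False passed below is make_template's create_scope_now_
    # positional argument.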
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertIsNot(inner_imm_var, inner_defer_var)
    self.assertIs(outer_imm_var, inner_imm_var)
    self.assertIs(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

      # Before we call the template, the scope property will be set to None.
      self.assertEqual(tc.variable_scope, None)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      try:
        # Try to connect with a kwarg which is unsupported.
        templatized_function(data, is_training=True)
      except TypeError:
        pass

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
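    # Calling a template reuses its variable scope, but each call opens a new,
    # uniquified name scope for the ops it adds.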
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo/add:0", outputs_a.name,
          "First application of template should get "
          "same name scope as variables.")
      self.assertEqual(
          "foo_1/add:0", outputs_b.name,
          "Second application of template should get "
          "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual(
        "foo_1", linear2.variable_scope.name,
        "New template gets a freshly uniquified variable scope "
        "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo_1_1/add:0", outputs_c.name,
          "First application of template would get "
          "same name scope as variables, but 'foo_1' is already "
          "a name scope.")
      self.assertEqual(
          "foo_1_2/add:0", outputs_d.name,
          "Second application of template should also get "
          "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially there are no variables created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling, there are variables created.
    ta()
    tb()
    # Check that the expected number of global variables were created.
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially there are no variables created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling, there are variables created.
    ta()
    tb()
    # Check that the expected number of trainable variables were created.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variable was created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially there are no variables created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling, there are variables created.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  @test_util.run_deprecated_v1
  def test_local_variables(self):
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially there are no variables created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling, there are variables created.
    ta()
    tb()
    # Check that the expected number of local variables were created.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables; the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertEqual(len(v1), len(v2))
    for v, w in zip(v1, v2):
      self.assertIs(v, w)
    self.assertEqual("s1/test/dummy:0", v1[0].name)


if __name__ == "__main__":
  test.main()