xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/template_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for make_template."""
16import functools
17import traceback
18
19from tensorflow.python.client import session
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import random_seed
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import init_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import template
28from tensorflow.python.ops import variable_scope
29from tensorflow.python.ops import variables
30import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
31from tensorflow.python.platform import test
32from tensorflow.python.training import gradient_descent
33
34
def variable_scoped_function(trainable=True):
  """Gets the zeros-initialized, shape-[1] variable "dummy" in the current scope.

  Args:
    trainable: Passed through to `get_variable` as the variable's
      `trainable` flag.

  Returns:
    Whatever `variable_scope.get_variable` returns for "dummy".
  """
  zeros = init_ops.zeros_initializer()
  return variable_scope.get_variable(
      "dummy", initializer=zeros, trainable=trainable, shape=[1])
39
40
def internally_variable_scoped_function(scope_name):
  """Gets the variable "dummy" from inside a nested variable scope.

  Args:
    scope_name: Name of the variable scope opened around the `get_variable`
      call, so the variable's name is prefixed with `scope_name/`.

  Returns:
    Whatever `variable_scope.get_variable` returns for "dummy" (shape [1],
    zeros initializer).
  """
  with variable_scope.variable_scope(scope_name):
    dummy = variable_scope.get_variable(
        "dummy", initializer=init_ops.zeros_initializer(), shape=[1])
  return dummy
45
46
def function_with_create(trainable):
  """Creates an extra variable via `tf.Variable`, then gets "dummy".

  The extra variable is created on every call, which the tests below use to
  check how a template reacts to new variables appearing on later calls.

  Args:
    trainable: `trainable` flag of the extra `variables.Variable`.

  Returns:
    Whatever `variable_scope.get_variable` returns for "dummy".
  """
  # Only the creation side effect matters; the new variable is discarded.
  _ = variables.Variable(0, trainable=trainable)
  zeros = init_ops.zeros_initializer()
  return variable_scope.get_variable("dummy", initializer=zeros, shape=[1])
52
53
def function_with_side_create(trainable, name="side"):
  """Creates an extra variable via `get_variable`, then gets "dummy".

  Args:
    trainable: `trainable` flag of the extra variable.
    name: Name passed to `get_variable` for the extra variable; tests pass a
      different name per call so each call requests a new variable.

  Returns:
    Whatever `variable_scope.get_variable` returns for "dummy".
  """
  # The return value is intentionally ignored; only creation matters.
  variable_scope.get_variable(name=name, trainable=trainable, shape=[1])
  zeros = init_ops.zeros_initializer()
  return variable_scope.get_variable("dummy", initializer=zeros, shape=[1])
59
60
def variable_scoped_function_with_local_variable():
  """Creates the local variable "local", then gets the variable "dummy".

  Both variables have shape [1] and use a zeros initializer.

  Returns:
    Whatever `variable_scope.get_variable` returns for "dummy".
  """
  make_zeros = init_ops.zeros_initializer
  variable_scope.get_local_variable(
      "local", initializer=make_zeros(), shape=[1])
  return variable_scope.get_variable(
      "dummy", initializer=make_zeros(), shape=[1])
66
67
class TemplateTest(test.TestCase):
  """Tests for `template.make_template` and `template.make_template_internal`."""

  @test_util.run_deprecated_v1
  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    # The session object itself is never used directly; the ops are run
    # through `self.evaluate` while the session context is active.
    with session.Session():
      self.evaluate(variables.global_variables_initializer())
      initial_test_loss = self.evaluate(test_loss)
      self.evaluate(train_op)
      final_test_loss = self.evaluate(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  def test_eager_delayed_store_pickup(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.

    This test also shows that it can pick up explicitly set variable stores
    even if they are only set before the first template usage.
    """
    with context.eager_mode():
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      store = variable_scope._VariableStore()
      store._store_eager_variables = True

      with variable_scope.with_variable_store(store):
        optimizer = gradient_descent.GradientDescentOptimizer(0.1)
        initial_test_loss = test_loss()
        optimizer.minimize(train_loss)
        final_test_loss = test_loss()

        # Parameters are tied, so the loss should have gone down after training.
        self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

      # Verify that the explicitly set store is not empty
      # and the make_template picked it up
      self.assertEqual(set(store._vars.keys()), {"line/w", "line/b"})

      # But the store should only get picked up once, so a second
      # store will go unused:
      second_store = variable_scope._VariableStore()
      second_store._store_eager_variables = True

      with variable_scope.with_variable_store(second_store):
        optimizer = gradient_descent.GradientDescentOptimizer(0.1)
        test_loss()
        optimizer.minimize(train_loss)
        test_loss()
      self.assertEmpty(second_store._vars)

  @test_util.run_in_graph_and_eager_modes
  def test_skip_stack_frames(self):
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_empty_name(self):
    tpl = template.make_template("", variable_scoped_function)
    with variable_scope.variable_scope("outer"):
      x = variable_scope.get_variable("x", [])
      v = tpl()
    self.assertEqual("outer/", tpl.variable_scope_name)
    self.assertEqual("outer//dummy:0", v.name)
    if context.executing_eagerly():
      # In eager mode `x` is not visible to the template since the template does
      # not rely on global collections.
      self.assertEqual(1, len(tpl.variables))
      self.assertIs(v, tpl.variables[0])
    else:
      self.assertEqual([x, v], tpl.variables)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegex(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError,
          "unique_name_ cannot be used when eager execution is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  @test_util.run_deprecated_v1
  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertIs(v1, v2)
    self.assertIs(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes
  def test_template_without_name(self):
    with self.assertRaisesRegex(ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes
  def test_make_template(self):
    # Test both that we can call it with positional and keywords.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertIs(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertIs(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertIs(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(1, len(nested1.variables))
      self.assertEqual(1, len(nested2.variables))
      self.assertIs(nested1.variables[0], v1)
      self.assertIs(nested2.variables[0], v2)
      self.assertEqual(1, len(nested1.trainable_variables))
      self.assertEqual(1, len(nested2.trainable_variables))
      self.assertIs(nested1.trainable_variables[0], v1)
      self.assertIs(nested2.trainable_variables[0], v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertIs(v1, v3)
    self.assertIs(v2, v4)
    for v, w in zip(tmpl1.variables, [v1, v2]):
      self.assertIs(v, w)
    for v, w in zip(tmpl1.trainable_variables, [v1, v2]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertIsNot(v1, v5)
    self.assertIsNot(v2, v6)
    for v, w in zip(tmpl2.variables, [v5, v6]):
      self.assertIs(v, w)
    for v, w in zip(tmpl2.trainable_variables, [v5, v6]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

    self.assertEqual(["nested", "nested_1"], list(tmpl1._trackable_children()))

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      self.assertEqual(len(v1), 1)
      self.assertEqual(len(v2), 1)

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1[0], v2[0])
      self.assertIs(nested1.trainable_variables[0], v1[0])
      self.assertIs(nested2.trainable_variables[0], v2[0])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    for v, w in zip(v1, v2):
      self.assertIs(v, w)

    # tmpl1 and tmpl2 should not share variables.
    for v, w in zip(v1, v3):
      self.assertIsNot(v, w)

    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes
  def test_immediate_scope_creation(self):
    # Create templates in scope a then call in scope b. make_template should
    # capture the scope the first time it is called, and make_immediate_template
    # should capture the scope at construction time.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertIsNot(inner_imm_var, inner_defer_var)
    self.assertIs(outer_imm_var, inner_imm_var)
    self.assertIs(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertIsNone(tc.variable_scope)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      # Connecting with an unsupported kwarg must fail. Assert the TypeError
      # explicitly so the test cannot pass if the call unexpectedly succeeds.
      with self.assertRaises(TypeError):
        templatized_function(data, is_training=True)

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo/add:0", outputs_a.name,
          "First application of template should get "
          "same name scope as variables.")
      self.assertEqual(
          "foo_1/add:0", outputs_b.name,
          "Second application of template should get "
          "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual(
        "foo_1", linear2.variable_scope.name,
        "New template gets a freshly uniquified variable scope "
        "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo_1_1/add:0", outputs_c.name,
          "First application of template would get "
          "same name scope as variables, but 'foo_1' is already "
          "a name scope.")
      self.assertEqual(
          "foo_1_2/add:0", outputs_d.name,
          "Second application of template should also get "
          "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # ta's function creates one variable; tb's function creates an extra
      # side variable in addition to "dummy".
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling there are variables created.
    ta()
    tb()
    # ta created "dummy"; tb created its side variable plus "dummy".
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Each template created exactly one trainable variable.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  @test_util.run_deprecated_v1
  def test_local_variables(self):
    # Make sure local_variables are tracked.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Only tb's function creates a local variable.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables, the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertEqual(len(v1), len(v2))
    for v, w in zip(v1, v2):
      self.assertIs(v, w)
    self.assertEqual("s1/test/dummy:0", v1[0].name)
795
796
if __name__ == "__main__":
  # Run this module's test cases only when the file is executed directly.
  test.main()
799