xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/core/converter_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for converter module."""
16
# NOTE(review): `imp` is no longer used by this file and was removed in
# Python 3.12; it can be dropped once nothing else depends on it.
import imp
import types

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test
26
27
class TestConverter(converter.Base):
  """Minimal concrete `converter.Base` used to call its helper methods."""
30
31
class ConversionOptionsTest(converter_testing.TestCase):
  """Tests for `converter.ConversionOptions`."""

  def test_to_ast(self):
    """Checks that `ConversionOptions.to_ast` round-trips through codegen.

    The options are serialized to an AST expression, spliced into a function
    returning it, loaded back as code and evaluated. The resulting object must
    carry the same settings as the original.
    """
    opts = converter.ConversionOptions()
    opts_ast = opts.to_ast()

    template = '''
    def f():
      return opts_ast
    '''
    opts_packed = templates.replace(template, opts_ast=opts_ast)

    reparsed, _, _ = loader.load_ast(opts_packed)
    # The regenerated expression is evaluated against an `ag__` namespace, so
    # install a stand-in module exposing the names it needs.
    # `types.ModuleType` replaces `imp.new_module`; `imp` was deprecated and
    # removed in Python 3.12.
    fake_ag = types.ModuleType('fake_ag')
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    reparsed.ag__ = fake_ag

    reparsed_opts = reparsed.f()

    self.assertEqual(opts.recursive, reparsed_opts.recursive)
    # Default options are not user-requested...
    self.assertEqual(opts.user_requested, False)
    # ...and, like every other field, the flag must survive the round trip.
    self.assertEqual(opts.user_requested, reparsed_opts.user_requested)
    self.assertEqual(
        opts.internal_convert_user_code,
        reparsed_opts.internal_convert_user_code)
    self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
58
59
class ConverterBaseTest(converter_testing.TestCase):
  """Tests for `converter.Base.get_definition_directive`."""

  def _returned_symbol(self, fn, return_stmt_index):
    """Transforms `fn` and fetches the expression its `return` yields.

    Args:
      fn: the function passed to `self.transform(fn, (), include_ast=True)`.
      return_stmt_index: position of the `return` statement within the body
        of the transformed function node.

    Returns:
      A `(returned_expression_node, transform_context)` pair.
    """
    _, fn_node, transform_ctx = self.transform(fn, (), include_ast=True)
    return fn_node.body[return_stmt_index].value, transform_ctx

  def test_get_definition_directive_basic(self):
    key = object

    def f():
      a = 1
      return a

    returned_a, ctx = self._returned_symbol(f, 1)

    # Exactly one definition of `a` reaches the return statement.
    [only_def] = anno.getanno(returned_a, anno.Static.ORIG_DEFINITIONS)
    only_def.directives[key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('bar'),
    }
    conv = TestConverter(ctx)
    result = conv.get_definition_directive(
        returned_a, key, 'test_arg', None)
    self.assertEqual(result.id, 'foo')

  def test_get_definition_directive_default(self):
    key = object

    def f():
      a = 1
      return a

    returned_a, ctx = self._returned_symbol(f, 1)

    # No directive was attached, so the supplied fallback must come back.
    conv = TestConverter(ctx)
    fallback = parser.parse_expression('default')
    result = conv.get_definition_directive(
        returned_a, key, 'test_arg', fallback)
    self.assertEqual(result.id, 'default')

  def test_get_definition_directive_multiple_consistent(self):
    key = object

    def f():
      a = 1
      if a:
        a = 2
      return a

    returned_a, ctx = self._returned_symbol(f, 2)

    all_defs = anno.getanno(returned_a, anno.Static.ORIG_DEFINITIONS)
    first_def, second_def = all_defs[0], all_defs[1]
    # Both definitions agree on 'test_arg'; 'other_arg' differs but is not
    # queried.
    first_def.directives[key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('bar'),
    }
    second_def.directives[key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('baz'),
    }
    conv = TestConverter(ctx)
    result = conv.get_definition_directive(
        returned_a, key, 'test_arg', None)
    self.assertEqual(result.id, 'foo')

  def test_get_definition_directive_multiple_inconsistent(self):
    key = object

    def f():
      a = 1
      if a:
        a = 2
      return a

    returned_a, ctx = self._returned_symbol(f, 2)

    all_defs = anno.getanno(returned_a, anno.Static.ORIG_DEFINITIONS)
    first_def, second_def = all_defs[0], all_defs[1]
    # The two reaching definitions disagree on 'test_arg'.
    first_def.directives[key] = {
        'test_arg': parser.parse_expression('foo'),
    }
    second_def.directives[key] = {
        'test_arg': parser.parse_expression('bar'),
    }
    conv = TestConverter(ctx)
    with self.assertRaises(ValueError):
      conv.get_definition_directive(returned_a, key, 'test_arg', None)
149
150
if __name__ == '__main__':
  # Run the test cases above via TensorFlow's test entry point when this file
  # is executed directly.
  test.main()
153