xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/ast_util_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ast_util module."""
16
17import ast
18import collections
19import textwrap
20
21import gast
22
23from tensorflow.python.autograph.pyct import anno
24from tensorflow.python.autograph.pyct import ast_util
25from tensorflow.python.autograph.pyct import loader
26from tensorflow.python.autograph.pyct import parser
27from tensorflow.python.autograph.pyct import pretty_printer
28from tensorflow.python.autograph.pyct import qual_names
29from tensorflow.python.platform import test
30
31
class AstUtilTest(test.TestCase):
  """Tests for the AST helpers in `ast_util`."""

  def setUp(self):
    super(AstUtilTest, self).setUp()
    # Maps (unparsed target, unparsed source) -> number of times
    # `_mock_apply_fn` was invoked with that pair.
    self._invocation_counts = collections.defaultdict(int)

  def assertAstMatches(self, actual_node, expected_node_src):
    """Asserts that `actual_node` matches the node parsed from source text.

    Args:
      actual_node: The AST node under test.
      expected_node_src: Source for the expected node. It is wrapped in
        parentheses and parsed; the first statement of the resulting module is
        used as the expected node.
    """
    expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
    msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        pretty_printer.fmt(expected_node),
        pretty_printer.fmt(actual_node))
    self.assertTrue(ast_util.matches(actual_node, expected_node), msg)

  def test_rename_symbols_basic(self):
    node = parser.parse('a + b')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
    source = parser.unparse(node, include_encoding_marker=False)
    expected_node_src = 'renamed_a + b'

    # The renamed identifier must be stored as a plain string, not a QN.
    self.assertIsInstance(node.value.left.id, str)
    # Round-trip: the node must match its own unparsed source.
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)

  def test_rename_symbols_attributes(self):
    node = parser.parse('b.c = b.c.d')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

    source = parser.unparse(node, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')

  def test_rename_symbols_nonlocal(self):
    node = parser.parse('nonlocal a, b, c')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.from_str('b'): qual_names.QN('renamed_b')})

    source = parser.unparse(node, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'nonlocal a, renamed_b, c')

  def test_rename_symbols_global(self):
    node = parser.parse('global a, b, c')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.from_str('b'): qual_names.QN('renamed_b')})

    source = parser.unparse(node, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'global a, renamed_b, c')

  def test_rename_symbols_annotations(self):
    node = parser.parse('a[i]')
    node = qual_names.resolve(node)
    anno.setanno(node, 'foo', 'bar')
    orig_anno = anno.getanno(node, 'foo')

    node = ast_util.rename_symbols(node,
                                   {qual_names.QN('a'): qual_names.QN('b')})

    # Renaming must carry over the very same annotation object.
    self.assertIs(anno.getanno(node, 'foo'), orig_anno)

  def test_rename_symbols_function(self):
    node = parser.parse('def f():\n  pass')
    node = ast_util.rename_symbols(node,
                                   {qual_names.QN('f'): qual_names.QN('f1')})

    source = parser.unparse(node, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'def f1():\n    pass')

  def test_copy_clean(self):
    node = parser.parse(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    setattr(node, '__foo', 'bar')
    new_node = ast_util.copy_clean(node)
    self.assertIsNot(new_node, node)
    self.assertFalse(hasattr(new_node, '__foo'))

  def test_copy_clean_preserves_annotations(self):
    node = parser.parse(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    anno.setanno(node, 'foo', 'bar')
    anno.setanno(node, 'baz', 1)
    new_node = ast_util.copy_clean(node, preserve_annos={'foo'})
    self.assertEqual(anno.getanno(new_node, 'foo'), 'bar')
    self.assertFalse(anno.hasanno(new_node, 'baz'))

  def test_keywords_to_dict(self):
    keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
    d = ast_util.keywords_to_dict(keywords)
    # Make sure we generate a usable dict node by returning it from a function
    # and compiling everything.
    node = parser.parse('def f(b): pass')
    node.body.append(ast.Return(d))
    result, _, _ = loader.load_ast(node)
    self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})

  def assertMatch(self, target_str, pattern_str):
    """Asserts that expression `target_str` matches pattern `pattern_str`."""
    node = parser.parse_expression(target_str)
    pattern = parser.parse_expression(pattern_str)
    self.assertTrue(
        ast_util.matches(node, pattern),
        '{!r} did not match pattern {!r}'.format(target_str, pattern_str))

  def assertNoMatch(self, target_str, pattern_str):
    """Asserts that expression `target_str` does not match `pattern_str`."""
    node = parser.parse_expression(target_str)
    pattern = parser.parse_expression(pattern_str)
    self.assertFalse(
        ast_util.matches(node, pattern),
        '{!r} unexpectedly matched pattern {!r}'.format(
            target_str, pattern_str))

  def test_matches_symbols(self):
    self.assertMatch('foo', '_')
    self.assertNoMatch('foo()', '_')
    self.assertMatch('foo + bar', 'foo + _')
    self.assertNoMatch('bar + bar', 'foo + _')
    self.assertNoMatch('foo - bar', 'foo + _')

  def test_matches_function_args(self):
    self.assertMatch('super(Foo, self).__init__(arg1, arg2)',
                     'super(_).__init__(_)')
    self.assertMatch('super().__init__()', 'super(_).__init__(_)')
    self.assertNoMatch('super(Foo, self).bar(arg1, arg2)',
                       'super(_).__init__(_)')
    self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)')
    self.assertNoMatch('super(Foo, self).__init__()',
                       'super(Bar, _).__init__(_)')

  def _mock_apply_fn(self, target, source):
    """Records one call, keyed by the unparsed (target, source) pair."""
    target = parser.unparse(target, include_encoding_marker=False)
    source = parser.unparse(source, include_encoding_marker=False)
    self._invocation_counts[(target.strip(), source.strip())] += 1

  def test_apply_to_single_assignments_dynamic_unpack(self):
    node = parser.parse('a, b, c = d')
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._mock_apply_fn)
    self.assertDictEqual(self._invocation_counts, {
        ('a', 'd[0]'): 1,
        ('b', 'd[1]'): 1,
        ('c', 'd[2]'): 1,
    })

  def test_apply_to_single_assignments_static_unpack(self):
    node = parser.parse('a, b, c = d, e, f')
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._mock_apply_fn)
    self.assertDictEqual(self._invocation_counts, {
        ('a', 'd'): 1,
        ('b', 'e'): 1,
        ('c', 'f'): 1,
    })

  def test_parallel_walk(self):
    src = """
      def f(a):
        return a + 1
    """
    node = parser.parse(textwrap.dedent(src))
    for child_a, child_b in ast_util.parallel_walk(node, node):
      self.assertEqual(child_a, child_b)

  def test_parallel_walk_string_leaves(self):
    src = """
      def f(a):
        global g
    """
    node = parser.parse(textwrap.dedent(src))
    for child_a, child_b in ast_util.parallel_walk(node, node):
      self.assertEqual(child_a, child_b)

  def test_parallel_walk_inconsistent_trees(self):
    node_1 = parser.parse(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    node_2 = parser.parse(
        textwrap.dedent("""
      def f(a):
        return a + (a * 2)
    """))
    node_3 = parser.parse(
        textwrap.dedent("""
      def f(a):
        return a + 2
    """))
    # Structurally different trees (a constant vs. a BinOp) are rejected.
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(node_1, node_2):
        pass
    # There is no particular reason to reject trees that differ only in the
    # value of a constant.
    # TODO(mdan): This should probably be allowed.
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(node_1, node_3):
        pass

  def assertLambdaNodes(self, matching_nodes, expected_bodies):
    """Asserts `matching_nodes` are lambdas whose bodies are `expected_bodies`.

    Every node must be a `gast.Lambda`. The unparsed bodies must equal
    `expected_bodies` as a multiset, in any order. A plain membership check
    would accept several nodes that all share one expected body.

    Args:
      matching_nodes: Sized iterable of AST nodes.
      expected_bodies: Iterable of expected lambda body source strings.
    """
    actual_bodies = []
    for node in matching_nodes:
      self.assertIsInstance(node, gast.Lambda)
      actual_bodies.append(
          parser.unparse(node.body, include_encoding_marker=False).strip())
    self.assertCountEqual(actual_bodies, expected_bodies)
243
244
245if __name__ == '__main__':
246  test.main()
247