# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ast_util module."""

import ast
import collections
import textwrap

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.platform import test


class AstUtilTest(test.TestCase):
  """Test cases exercising the helpers exposed by `ast_util`."""

  def assertAstMatches(self, actual_node, expected_node_src):
    """Asserts that `actual_node` matches the AST of `expected_node_src`.

    Args:
      actual_node: the AST node under test.
      expected_node_src: source text of the expected node; it is wrapped in
        parentheses before being parsed with `gast.parse`.
    """
    wrapped_src = '({})'.format(expected_node_src)
    expected_node = gast.parse(wrapped_src).body[0]
    # The failure message is always built up front, showing both trees.
    expected_dump = pretty_printer.fmt(expected_node)
    actual_dump = pretty_printer.fmt(actual_node)
    msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        expected_dump, actual_dump)
    self.assertTrue(ast_util.matches(actual_node, expected_node), msg)

  def setUp(self):
    """Resets the per-test call counter filled in by `_mock_apply_fn`."""
    super().setUp()
    self._invocation_counts = collections.defaultdict(int)

  def test_rename_symbols_basic(self):
    """Renaming `a` in `a + b` yields `renamed_a + b`."""
    tree = qual_names.resolve(parser.parse('a + b'))
    renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}

    tree = ast_util.rename_symbols(tree, renames)
    source = parser.unparse(tree, include_encoding_marker=False)
    expected_node_src = 'renamed_a + b'

    self.assertIsInstance(tree.value.left.id, str)
    self.assertAstMatches(tree, source)
    self.assertAstMatches(tree, expected_node_src)

  def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` rewrites both of its occurrences."""
    tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
    renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

    tree = ast_util.rename_symbols(tree, renames)

    source = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')

  def test_rename_symbols_nonlocal(self):
    """Renaming applies to names listed in a `nonlocal` statement."""
    tree = qual_names.resolve(parser.parse('nonlocal a, b, c'))
    renames = {qual_names.from_str('b'): qual_names.QN('renamed_b')}

    tree = ast_util.rename_symbols(tree, renames)

    source = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'nonlocal a, renamed_b, c')

  def test_rename_symbols_global(self):
    """Renaming applies to names listed in a `global` statement."""
    tree = qual_names.resolve(parser.parse('global a, b, c'))
    renames = {qual_names.from_str('b'): qual_names.QN('renamed_b')}

    tree = ast_util.rename_symbols(tree, renames)

    source = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'global a, renamed_b, c')

  def test_rename_symbols_annotations(self):
    """An annotation set before renaming is the same object afterwards."""
    tree = qual_names.resolve(parser.parse('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    orig_anno = anno.getanno(tree, 'foo')

    renames = {qual_names.QN('a'): qual_names.QN('b')}
    tree = ast_util.rename_symbols(tree, renames)

    self.assertIs(anno.getanno(tree, 'foo'), orig_anno)

  def test_rename_symbols_function(self):
    """Renaming applies to the name of a function definition."""
    tree = parser.parse('def f():\n pass')
    renames = {qual_names.QN('f'): qual_names.QN('f1')}
    tree = ast_util.rename_symbols(tree, renames)

    source = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(source.strip(), 'def f1():\n pass')

  def test_copy_clean(self):
    """`copy_clean` returns a new node without the extra `__foo` attribute."""
    fn_src = textwrap.dedent("""
        def f(a):
          return a + 1
    """)
    original = parser.parse(fn_src)
    # Attach an extra attribute to the original node only.
    setattr(original, '__foo', 'bar')
    duplicate = ast_util.copy_clean(original)
    self.assertIsNot(duplicate, original)
    self.assertFalse(hasattr(duplicate, '__foo'))

  def test_copy_clean_preserves_annotations(self):
    """Only annotations named in `preserve_annos` are present on the copy."""
    fn_src = textwrap.dedent("""
        def f(a):
          return a + 1
    """)
    original = parser.parse(fn_src)
    anno.setanno(original, 'foo', 'bar')
    anno.setanno(original, 'baz', 1)
    duplicate = ast_util.copy_clean(original, preserve_annos={'foo'})
    self.assertEqual(anno.getanno(duplicate, 'foo'), 'bar')
    self.assertFalse(anno.hasanno(duplicate, 'baz'))

  def test_keywords_to_dict(self):
    """The node built by `keywords_to_dict` compiles and evaluates to a dict."""
    call_node = parser.parse_expression('f(a=b, c=1, d=\'e\')')
    dict_node = ast_util.keywords_to_dict(call_node.keywords)
    # Make sure we generate a usable dict node by returning it from a function
    # and compiling everything.
    fn_node = parser.parse('def f(b): pass')
    fn_node.body.append(ast.Return(dict_node))
    module, _, _ = loader.load_ast(fn_node)
    self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})

  def assertMatch(self, target_str, pattern_str):
    """Asserts that the expression `target_str` matches `pattern_str`."""
    target_node = parser.parse_expression(target_str)
    pattern_node = parser.parse_expression(pattern_str)
    self.assertTrue(ast_util.matches(target_node, pattern_node))

  def assertNoMatch(self, target_str, pattern_str):
    """Asserts that the expression `target_str` does not match `pattern_str`."""
    target_node = parser.parse_expression(target_str)
    pattern_node = parser.parse_expression(pattern_str)
    self.assertFalse(ast_util.matches(target_node, pattern_node))

  def test_matches_symbols(self):
    """Checks `_` patterns against plain symbols and binary operations."""
    self.assertMatch('foo', '_')
    self.assertNoMatch('foo()', '_')
    self.assertMatch('foo + bar', 'foo + _')
    self.assertNoMatch('bar + bar', 'foo + _')
    self.assertNoMatch('foo - bar', 'foo + _')

  def test_matches_function_args(self):
    """Checks `_` patterns used in call-argument positions."""
    init_pattern = 'super(_).__init__(_)'
    self.assertMatch('super(Foo, self).__init__(arg1, arg2)', init_pattern)
    self.assertMatch('super().__init__()', init_pattern)
    self.assertNoMatch('super(Foo, self).bar(arg1, arg2)', init_pattern)
    self.assertMatch('super(Foo, self).__init__()',
                     'super(Foo, _).__init__(_)')
    self.assertNoMatch('super(Foo, self).__init__()',
                       'super(Bar, _).__init__(_)')

  def _mock_apply_fn(self, target, source):
    """Counts one call, keyed by the stripped source of both nodes."""
    target_src = parser.unparse(target, include_encoding_marker=False)
    source_src = parser.unparse(source, include_encoding_marker=False)
    call_key = (target_src.strip(), source_src.strip())
    self._invocation_counts[call_key] += 1

  def test_apply_to_single_assignments_dynamic_unpack(self):
    """Unpacking a single value pairs each target with an indexed element."""
    assign = parser.parse('a, b, c = d')
    ast_util.apply_to_single_assignments(
        assign.targets, assign.value, self._mock_apply_fn)
    expected_calls = {
        ('a', 'd[0]'): 1,
        ('b', 'd[1]'): 1,
        ('c', 'd[2]'): 1,
    }
    self.assertDictEqual(self._invocation_counts, expected_calls)

  def test_apply_to_single_assignments_static_unpack(self):
    """Unpacking a tuple literal pairs each target with its own element."""
    assign = parser.parse('a, b, c = d, e, f')
    ast_util.apply_to_single_assignments(
        assign.targets, assign.value, self._mock_apply_fn)
    expected_calls = {
        ('a', 'd'): 1,
        ('b', 'e'): 1,
        ('c', 'f'): 1,
    }
    self.assertDictEqual(self._invocation_counts, expected_calls)

  def test_parallel_walk(self):
    """Walking a tree against itself yields pairs of equal children."""
    src = """
      def f(a):
        return a + 1
    """
    tree = parser.parse(textwrap.dedent(src))
    for left, right in ast_util.parallel_walk(tree, tree):
      self.assertEqual(left, right)

  def test_parallel_walk_string_leaves(self):
    """Same as above, for a function whose body is a `global` statement."""
    src = """
      def f(a):
        global g
    """
    tree = parser.parse(textwrap.dedent(src))
    for left, right in ast_util.parallel_walk(tree, tree):
      self.assertEqual(left, right)

  def test_parallel_walk_inconsistent_trees(self):
    """Walking two differing trees raises `ValueError`."""
    base_tree = parser.parse(
        textwrap.dedent("""
            def f(a):
              return a + 1
        """))
    different_shape_tree = parser.parse(
        textwrap.dedent("""
            def f(a):
              return a + (a * 2)
        """))
    different_constant_tree = parser.parse(
        textwrap.dedent("""
            def f(a):
              return a + 2
        """))
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(base_tree, different_shape_tree):
        pass
    # There is no particular reason to reject trees that differ only in the
    # value of a constant.
    # TODO(mdan): This should probably be allowed.
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(base_tree, different_constant_tree):
        pass

  def assertLambdaNodes(self, matching_nodes, expected_bodies):
    """Asserts every node is a `gast.Lambda` whose body is an expected one.

    Args:
      matching_nodes: the nodes to check.
      expected_bodies: collection of stripped lambda body sources; it must
        have the same length as `matching_nodes`.
    """
    self.assertEqual(len(matching_nodes), len(expected_bodies))
    for lambda_node in matching_nodes:
      self.assertIsInstance(lambda_node, gast.Lambda)
      body_src = parser.unparse(
          lambda_node.body, include_encoding_marker=False)
      self.assertIn(body_src.strip(), expected_bodies)


if __name__ == '__main__':
  test.main()