# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for converter module."""

import types

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test


class TestConverter(converter.Base):
  """Trivial concrete subclass used to call `converter.Base` helpers."""


class ConversionOptionsTest(converter_testing.TestCase):
  """Tests for `converter.ConversionOptions`."""

  def test_to_ast(self):
    """Options serialized with `to_ast` reload with the same values."""
    opts = converter.ConversionOptions()
    opts_ast = opts.to_ast()

    template = '''
    def f():
      return opts_ast
    '''
    opts_packed = templates.replace(template, opts_ast=opts_ast)

    reparsed, _, _ = loader.load_ast(opts_packed)
    # Provide a stand-in `ag__` module exposing the `ConversionOptions` and
    # `Feature` names the reparsed code is expected to reference.
    # `types.ModuleType` replaces `imp.new_module`; the `imp` module was
    # deprecated and removed in Python 3.12.
    fake_ag = types.ModuleType('fake_ag')
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    reparsed.ag__ = fake_ag

    reparsed_opts = reparsed.f()

    self.assertEqual(opts.recursive, reparsed_opts.recursive)
    self.assertEqual(opts.user_requested, False)
    # Also check the round trip, not just the default value.
    self.assertEqual(opts.user_requested, reparsed_opts.user_requested)
    self.assertEqual(
        opts.internal_convert_user_code,
        reparsed_opts.internal_convert_user_code)
    self.assertEqual(opts.optional_features, reparsed_opts.optional_features)


class ConverterBaseTest(converter_testing.TestCase):
  """Tests for `converter.Base.get_definition_directive`."""

  def test_get_definition_directive_basic(self):
    """A directive on the single reaching definition is returned."""

    directive_key = object

    def f():
      a = 1
      return a

    _, node, ctx = self.transform(f, (), include_ast=True)

    # body[1] is `return a`; its value is the `a` symbol being looked up.
    symbol_a = node.body[1].value
    defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
    defs.directives[directive_key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('bar'),
    }
    c = TestConverter(ctx)
    value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
                                       None)
    self.assertEqual(value.id, 'foo')

  def test_get_definition_directive_default(self):
    """The supplied default is returned when no directive is attached."""

    directive_key = object

    def f():
      a = 1
      return a

    _, node, ctx = self.transform(f, (), include_ast=True)

    symbol_a = node.body[1].value
    c = TestConverter(ctx)
    value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
                                       parser.parse_expression('default'))
    self.assertEqual(value.id, 'default')

  def test_get_definition_directive_multiple_consistent(self):
    """Definitions that agree on the requested arg yield that value."""

    directive_key = object

    def f():
      a = 1
      if a:
        a = 2
      return a

    _, node, ctx = self.transform(f, (), include_ast=True)

    # body[2] is `return a`; two definitions of `a` reach it.
    symbol_a = node.body[2].value
    defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
    # 'test_arg' agrees across both definitions; 'other_arg' deliberately
    # differs to show that only the requested argument must be consistent.
    defs[0].directives[directive_key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('bar'),
    }
    defs[1].directives[directive_key] = {
        'test_arg': parser.parse_expression('foo'),
        'other_arg': parser.parse_expression('baz'),
    }
    c = TestConverter(ctx)
    value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
                                       None)
    self.assertEqual(value.id, 'foo')

  def test_get_definition_directive_multiple_inconsistent(self):
    """Conflicting values for the requested arg raise ValueError."""

    directive_key = object

    def f():
      a = 1
      if a:
        a = 2
      return a

    _, node, ctx = self.transform(f, (), include_ast=True)

    symbol_a = node.body[2].value
    defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
    defs[0].directives[directive_key] = {
        'test_arg': parser.parse_expression('foo'),
    }
    defs[1].directives[directive_key] = {
        'test_arg': parser.parse_expression('bar'),
    }
    c = TestConverter(ctx)
    with self.assertRaises(ValueError):
      c.get_definition_directive(symbol_a, directive_key, 'test_arg', None)


if __name__ == '__main__':
  test.main()