# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for tests in this module."""

import contextlib
import inspect
import io
import sys
import types

from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.framework import ops
from tensorflow.python.platform import test


def allowlist(f):
  """Helper that marks a callable as allowlisted.

  On first use, this registers an empty module named
  'allowlisted_module_for_testing' in `sys.modules` and prepends a
  `config.DoNotConvert` rule for that module to `config.CONVERSION_RULES`.
  It then attributes `f` to that module by rewriting `f.__module__`.

  Args:
    f: The callable to allowlist. It is modified in place.

  Returns:
    `f` itself, so this helper can also be used as a decorator.
  """
  if 'allowlisted_module_for_testing' not in sys.modules:
    # `imp.new_module` is gone as of Python 3.12; `types.ModuleType` is the
    # documented replacement and creates the same kind of empty module.
    allowlisted_mod = types.ModuleType('allowlisted_module_for_testing')
    sys.modules['allowlisted_module_for_testing'] = allowlisted_mod
    # Prepended rather than appended; presumably rules are matched in order,
    # giving this one priority -- confirm against config.CONVERSION_RULES use.
    config.CONVERSION_RULES = (
        (config.DoNotConvert('allowlisted_module_for_testing'),) +
        config.CONVERSION_RULES)

  f.__module__ = 'allowlisted_module_for_testing'
  return f


def is_inside_generated_code():
  """Tests whether the caller is generated code. Implementation-specific.

  Starting from the caller's frame, walks up the stack past any frames whose
  function is named `converted_call` or `_call_unconverted`. It then checks
  whether the first remaining frame has a local named `ag__`, the name under
  which `TestingTranspiler.get_extra_locals` exposes the AutoGraph module.

  Returns:
    True if such a frame exists and defines `ag__` in its locals, False
    otherwise (including when the walk runs off the top of the stack).
  """
  frame = inspect.currentframe()
  try:
    frame = frame.f_back

    internal_stack_functions = ('converted_call', '_call_unconverted')
    # Walk up the stack until we're out of the internal functions.
    while (frame is not None and
           frame.f_code.co_name in internal_stack_functions):
      frame = frame.f_back
    if frame is None:
      return False

    return 'ag__' in frame.f_locals
  finally:
    # Drop the frame reference explicitly so this function's frame does not
    # keep a stack frame alive through a reference cycle.
    del frame


class TestingTranspiler(api.PyToTF):
  """Testing version that only applies given transformations.

  Attributes:
    transformed_ast: The AST returned by the most recent `transform_ast`
      call, or None if no transformation has happened yet.
    transform_ctx: The context passed to the most recent `transform_ast`
      call, or None if no transformation has happened yet.
  """

  def __init__(self, converters, ag_overrides):
    """Creates the transpiler.

    Args:
      converters: A single converter, or a list/tuple of converters. Each
        must provide `transform(node, ctx)`.
      ag_overrides: Optional mapping of names to values that are layered on
        top of the `ag__` module handed to converted code. A falsy value
        disables the override.
    """
    super(TestingTranspiler, self).__init__()
    if isinstance(converters, (list, tuple)):
      self._converters = converters
    else:
      self._converters = (converters,)
    self.transformed_ast = None
    # Initialized eagerly, like `transformed_ast`, so the attribute always
    # exists even before `transform_ast` runs.
    self.transform_ctx = None
    self._ag_overrides = ag_overrides

  def get_extra_locals(self):
    """Returns the parent's extra locals, with `ag__` patched if requested.

    When overrides are set, `ag__` is replaced by a fresh module whose
    namespace is a copy of the original `ag__` module's `__dict__`, updated
    with `ag_overrides`. The original module object is left untouched.
    """
    retval = super(TestingTranspiler, self).get_extra_locals()
    if self._ag_overrides:
      # `types.ModuleType` replaces `imp.new_module` (removed in Python 3.12).
      modified_ag = types.ModuleType('fake_autograph')
      modified_ag.__dict__.update(retval['ag__'].__dict__)
      modified_ag.__dict__.update(self._ag_overrides)
      retval['ag__'] = modified_ag
    return retval

  def transform_ast(self, node, ctx):
    """Runs the initial analysis, then only the configured converters.

    The resulting AST and the context are recorded on the instance so tests
    can inspect them afterwards.
    """
    node = self.initial_analysis(node, ctx)

    for c in self._converters:
      node = c.transform(node, ctx)

    self.transformed_ast = node
    self.transform_ctx = ctx
    return node


class TestCase(test.TestCase):
  """Base class for unit tests in this module. Contains relevant utilities."""

  def setUp(self):
    # AutoGraph tests must run in graph mode to properly test control flow.
    # Note: `self.graph` holds the `as_default()` context manager returned for
    # a fresh graph, which is entered here and exited in `tearDown`.
    self.graph = ops.Graph().as_default()
    self.graph.__enter__()

  def tearDown(self):
    self.graph.__exit__(None, None, None)

  @contextlib.contextmanager
  def assertPrints(self, expected_result):
    """Asserts that the managed block writes exactly `expected_result`.

    Output written to `sys.stdout` inside the `with` block is captured. The
    previously active stdout is restored on exit, even on error, rather than
    forcing `sys.__stdout__`, so any outer redirection (e.g. a test runner's
    capture) survives. If the block raises, the exception propagates and no
    comparison is made.

    Args:
      expected_result: The exact text expected on stdout.
    """
    out_capturer = io.StringIO()
    with contextlib.redirect_stdout(out_capturer):
      yield
    self.assertEqual(out_capturer.getvalue(), expected_result)

  def transform(
      self, f, converter_module, include_ast=False, ag_overrides=None):
    """Converts `f` applying only the given converter(s).

    Args:
      f: The function to convert.
      converter_module: A converter, or a list/tuple of converters, passed to
        `TestingTranspiler`.
      include_ast: If True, also return the transformed AST and the context
        used to produce it.
      ag_overrides: Optional mapping layered on top of the `ag__` module seen
        by the converted code.

    Returns:
      The first value returned by `transform_function`. If `include_ast` is
      True, a tuple `(transformed, transformed_ast, transform_ctx)` is
      returned instead.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=api)

    tr = TestingTranspiler(converter_module, ag_overrides)
    transformed, _, _ = tr.transform_function(f, program_ctx)

    if include_ast:
      return transformed, tr.transformed_ast, tr.transform_ctx

    return transformed