# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core conversion logic, serves as main point of access."""

import functools
import inspect
import sys
import unittest

from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.util import tf_inspect


_ALLOWLIST_CACHE = cache.UnboundInstanceCache()


def _is_of_known_loaded_module(f, module_name):
  """Tests whether `f` is a top-level member of an already-imported module.

  Args:
    f: The entity to look for.
    module_name: Name of the module, looked up in `sys.modules`. The module is
      never imported by this function.

  Returns:
    True if the module is loaded and one of the values in its namespace is the
    very same object as `f` (identity, not equality). `None` never matches.
  """
  mod = sys.modules.get(module_name, None)
  if mod is None:
    return False
  if f is None:
    # A `None` value in the module namespace must not count as a match.
    return False
  return any(f is v for v in mod.__dict__.values())


def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether the function or method is an instance of a known type.

  Args:
    f: The function or method to test.
    module_name: Name of the module that defines the type. Only consulted if it
      is already present in `sys.modules`.
    entity_name: Name of the type inside that module.

  Returns:
    True if `f`, or the function underlying the bound method `f`, is an
    instance of the named type. False otherwise, including when the module is
    not loaded or does not define `entity_name`.
  """
  if (module_name not in sys.modules or
      not hasattr(sys.modules[module_name], entity_name)):
    return False
  type_entity = getattr(sys.modules[module_name], entity_name)
  if isinstance(f, type_entity):
    # The method is of this type. Example:
    #
    # o = ClassType()
    # function(o.method)()
    return True
  # Note: inspect is required here, to avoid unpacking tf.function decorators.
  if inspect.ismethod(f):
    # The unbound method is of this type. Example:
    #
    # class ClassType:
    #   @function
    #   def method(self):
    #     ...
    # o = ClassType()
    # o.method()
    if isinstance(f.__func__, type_entity):
      return True
  return False


def _is_tensorflow_plugin(o):
  """Tests whether `o` belongs to a module tagged `_IS_TENSORFLOW_PLUGIN`."""
  module_ref = getattr(o, '__module__', None)
  # Preserve the historical check, in case `__module__` holds an object that
  # carries the marker itself.
  if hasattr(module_ref, '_IS_TENSORFLOW_PLUGIN'):
    return True
  # `__module__` is normally the module *name*; the marker is an attribute of
  # the module object, so the name has to be resolved through `sys.modules`.
  if isinstance(module_ref, str):
    return hasattr(sys.modules.get(module_ref), '_IS_TENSORFLOW_PLUGIN')
  return False


def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  Args:
    o: A Python entity.

  Returns:
    True if `o` must be left unconverted: wrapt-decorated callables,
    `functools.lru_cache` wrappers, constructors, members of a few standard
    library modules, and members of TensorFlow plugin modules. False otherwise.
  """

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library.
  if _is_tensorflow_plugin(o):
    logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
    return True

  return False


# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package.

  The rules are tried in order and the first one that matches decides:
  the module-based rules in `config.CONVERSION_RULES`, generator functions,
  callable objects whose `__call__` is allowed, methods of `TestCase`
  subclasses or of allowed classes, and finally namedtuple types.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are allowed if their
      __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. When `True`, only
      namedtuple types with no namedtuple among their direct bases are
      allowed, i.e. namedtuple subclasses are not allowed. When `False`,
      namedtuple types and their subclasses are both allowed.

  Returns:
    Boolean. True if the entity is allowed (and therefore not converted),
    False otherwise.
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  # Examples of callables that lack a __module__ property include builtins.
  if hasattr(m, '__name__'):
    for rule in config.CONVERSION_RULES:
      action = rule.get_action(m)
      if action == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, rule)
        return False
      elif action == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, rule)
        return True

  # The check for __code__ below is because isgeneratorfunction crashes
  # without one.
  if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  if (check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__')):
    # Callable objects: allowed if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
      return True

  owner_class = None
  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    # class Custom(tf.Foo):
    #   pass
    #
    # baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is function.TfMethodTarget:
      owner_class = o.__self__.target_class
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_allowlisted(
          owner_class,
          check_call_override=False,
          allow_namedtuple_subclass=True):
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if allow_namedtuple_subclass:
      if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
        logging.log(2, 'Allowlisted: %s: named tuple', o)
        return True
    else:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False


def is_in_allowlist_cache(entity, options):
  """Returns whether `(entity, options)` was recorded in the allowlist cache.

  Args:
    entity: The Python entity used as the primary cache key.
    options: The secondary cache key.

  Returns:
    The result of the cache lookup, or False if the lookup raises `TypeError`.
  """
  try:
    return _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return False


def cache_allowlisted(entity, options):
  """Records `(entity, options)` in the allowlist cache, on a best-effort basis.

  Args:
    entity: The Python entity used as the primary cache key.
    options: The secondary cache key.
  """
  try:
    _ALLOWLIST_CACHE[entity][options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    pass