xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/impl/conversion.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core conversion logic, serves as main point of access."""
16
17import functools
18import inspect
19import sys
20import unittest
21
22from tensorflow.python.autograph.core import config
23from tensorflow.python.autograph.pyct import cache
24from tensorflow.python.autograph.pyct import inspect_utils
25from tensorflow.python.autograph.utils import ag_logging as logging
26from tensorflow.python.eager import function
27from tensorflow.python.util import tf_inspect
28
29
# Cache of entities already found to be allowlisted. It is indexed first by
# entity and then by conversion options (see is_in_allowlist_cache and
# cache_allowlisted below).
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
31
32
33def _is_of_known_loaded_module(f, module_name):
34  mod = sys.modules.get(module_name, None)
35  if mod is None:
36    return False
37  if any(v is not None for v in mod.__dict__.values() if f is v):
38    return True
39  return False
40
41
42def _is_known_loaded_type(f, module_name, entity_name):
43  """Tests whether the function or method is an instance of a known type."""
44  if (module_name not in sys.modules or
45      not hasattr(sys.modules[module_name], entity_name)):
46    return False
47  type_entity = getattr(sys.modules[module_name], entity_name)
48  if isinstance(f, type_entity):
49    # The method if of this type. Example:
50    #
51    # o = ClassType()
52    # function(o.method)()
53    return True
54  # Note: inspect is required here, to avoid unpacking tf.function decorators.
55  if inspect.ismethod(f):
56    # The unbound method if of this type. Example:
57    #
58    # class ClassType:
59    #   @function
60    #   def method(self):
61    #     ...
62    # o = ClassType()
63    # o.method()
64    if isinstance(f.__func__, type_entity):
65      return True
66  return False
67
68
def is_unsupported(o):
  """Checks whether AutoGraph does not support an entity at all.

  Unsupported entities are "permanently allowed": AutoGraph leaves them
  unconverted and they run as-is.

  Args:
    o: A Python entity, typically a callable.

  Returns:
    True if `o` falls into one of these unsupported categories, False
    otherwise:
      - wrapt wrappers;
      - `functools.lru_cache` wrappers;
      - constructors;
      - members of a few standard library modules;
      - functions whose module is tagged as a TensorFlow plugin.
  """

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library, which tags the modules it creates
  # with `_IS_TENSORFLOW_PLUGIN`. `o.__module__` is normally the module's
  # *name* (a str), so the module object must be looked up in sys.modules.
  # Probing the string itself for the marker can never match. A non-str
  # `__module__` is still probed directly, as before.
  plugin_module = getattr(o, '__module__', None)
  if isinstance(plugin_module, str):
    plugin_module = sys.modules.get(plugin_module)
  if (plugin_module is not None and
      hasattr(plugin_module, '_IS_TENSORFLOW_PLUGIN')):
    logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
    return True

  return False
109
110
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Allowlisted entities are left unconverted by AutoGraph. Examples of
  allowed entities include all members of the tensorflow package.

  The checks below run in order, and the first one that reaches a decision
  wins:
    1. The rules in `config.CONVERSION_RULES` are matched against the
       entity's module (when it has a name). The first rule that yields
       `Action.CONVERT` (not allowed) or `Action.DO_NOT_CONVERT` (allowed)
       decides.
    2. Generator functions are allowed.
    3. Callable objects that are not classes are allowed if their
       `__call__` is allowed (only when `check_call_override` is True).
    4. Methods are allowed if their owner class subclasses
       `unittest.TestCase`, or if the class that defines the method is
       itself allowed.
    5. namedtuple types are allowed (see `allow_namedtuple_subclass`).
    6. Anything else is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which callable objects (other than
      classes) are allowed if their __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. Despite its name,
      when `True` the namedtuple rule only accepts namedtuple types none of
      whose direct bases is itself a namedtuple. In other words, namedtuple
      subclasses are not allowed by that rule. When `False`, both namedtuple
      types and their subclasses are allowed.

  Returns:
    Boolean: True if `o` is allowlisted, False otherwise.
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  # Examples of callables that lack a __module__ property include builtins.
  # Rules are evaluated against the module, in order; a rule whose action is
  # neither CONVERT nor DO_NOT_CONVERT defers to the next one.
  if hasattr(m, '__name__'):
    for rule in config.CONVERSION_RULES:
      action = rule.get_action(m)
      if action == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, rule)
        return False
      elif action == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, rule)
        return True

  # The check for __code__ below is because isgeneratorfunction crashes
  # without one.
  if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  if (check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__')):
    # Callable objects: allowed if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
      return True

  owner_class = None
  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.

    owner_class = inspect_utils.getmethodclass(o)
    # When the method is bound to a function.TfMethodTarget, judge the class
    # it targets instead. This assumes `o.__self__` is that TfMethodTarget
    # instance; confirm against TfMethodTarget.
    if owner_class is function.TfMethodTarget:
      owner_class = o.__self__.target_class
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      # Narrow down to the class that actually defines the method. This is
      # what makes an override in a user subclass (`Custom.bar` above) not
      # inherit the allowed status of the base class.
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_allowlisted(
          owner_class,
          check_call_override=False,
          allow_namedtuple_subclass=True):
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if allow_namedtuple_subclass:
      # Only accept types with no namedtuple among their direct bases.
      if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
        logging.log(2, 'Allowlisted: %s: named tuple', o)
        return True
    else:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
212
213
def is_in_allowlist_cache(entity, options):
  """Returns whether `entity` is recorded in the allowlist cache for `options`.

  Entities that cannot serve as cache keys (unhashable, or not weakly
  referenceable) are reported as not cached instead of raising.
  """
  try:
    found = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    found = False
  return found
220
221
def cache_allowlisted(entity, options):
  """Records `entity` as allowlisted under `options` in the allowlist cache.

  This is best-effort. Entities that are unhashable or do not support weak
  references are silently left uncached.
  """
  try:
    entries_for_entity = _ALLOWLIST_CACHE[entity]
    entries_for_entity[options] = True
  except TypeError:
    pass
228