xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/tf_should_use.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Decorator that provides a warning if the wrapped object is never used."""
16import copy
17import sys
18import textwrap
19import traceback
20import types
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import ops
24from tensorflow.python.platform import tf_logging
25from tensorflow.python.util import tf_decorator
26
27
class _TFShouldUseHelper(object):
  """Tracks whether a value returned by a `should_use_result` function is used.

  `sate()` marks the value as used. If it never is:

  * when this helper is garbage collected, an error is logged via
    `tf_logging.error`;
  * if it was created while building a `tf.function` graph with
    `error_in_function=True`, a `RuntimeError` is raised from the callback
    registered with `ops.add_exit_callback_to_default_func_graph`.

  When executing eagerly the value starts out sated unless `warn_in_eager` is
  set. Inside a tf.function with `error_in_function=False` it always starts
  out sated (tf.functions use autodeps, which resolve the main issues this
  wrapper warns about).
  """

  def __init__(self, type_, repr_, stack_frame, error_in_function,
               warn_in_eager):
    """Initializes the helper.

    Args:
      type_: Type of the tracked value, used in the report.
      repr_: `repr` of the tracked value, used in the report.
      stack_frame: Frame whose stack is reported as the creation site.
      error_in_function: Whether to raise when unused inside a tf.function.
      warn_in_eager: Whether to track usage when executing eagerly.
    """
    self._type = type_
    self._repr = repr_
    self._stack_frame = stack_frame
    self._error_in_function = error_in_function
    if context.executing_eagerly():
      # If warn_in_eager, sated == False.  Otherwise true.
      self._sated = not warn_in_eager
    elif ops.inside_function():
      if error_in_function:
        self._sated = False
        ops.add_exit_callback_to_default_func_graph(
            lambda: self._check_sated(raise_error=True))
      else:
        self._sated = True
    else:
      # TF1 graph building mode
      self._sated = False

  def sate(self):
    """Marks the value as used and drops the references kept for reporting."""
    self._sated = True
    self._type = None
    self._repr = None
    self._stack_frame = None
    self._logging_module = None

  def _check_sated(self, raise_error):
    """Reports the tracked value as unused unless it has been sated.

    Args:
      raise_error: If True, raise a `RuntimeError` (and sate, so deletion
        does not report again); otherwise log via `tf_logging.error`.

    Raises:
      RuntimeError: If `raise_error` is True and the value was never used.
    """
    if self._sated:
      return
    # Each format_stack entry is '  File ...\n    <code>\n'. Joining the
    # rstripped entries with '' would glue one frame's code line onto the
    # next frame's "File" line, so put one entry per line.
    creation_stack = '\n'.join(
        line.rstrip()
        for line in traceback.format_stack(self._stack_frame, limit=5))
    if raise_error:
      try:
        raise RuntimeError(
            'Object was never used (type {}): {}.  If you want to mark it as '
            'used call its "mark_used()" method.  It was originally created '
            'here:\n{}'.format(self._type, self._repr, creation_stack))
      finally:
        self.sate()
    else:
      tf_logging.error(
          '==================================\n'
          'Object was never used (type {}):\n{}\nIf you want to mark it as '
          'used call its "mark_used()" method.\nIt was originally created '
          'here:\n{}\n'
          '=================================='
          .format(self._type, self._repr, creation_stack))

  def __del__(self):
    # __init__ may have raised before `_sated` was assigned; a half-built
    # helper has nothing meaningful to report.
    if getattr(self, '_sated', True):
      return
    self._check_sated(raise_error=False)
90
91
def _new__init__(self, wrapped_value, tf_should_use_helper):
  """`__init__` installed on wrapper classes built by `_get_wrapper`.

  Stores the usage-tracking helper, then the wrapped value, on the wrapper
  instance. Both assignments go through the instance's own `__setattr__`.

  Args:
    self: The wrapper instance being initialized.
    wrapped_value: The object being wrapped.
    tf_should_use_helper: The helper object that tracks usage.
  """
  bookkeeping = (('_tf_should_use_helper', tf_should_use_helper),
                 ('_tf_should_use_wrapped_value', wrapped_value))
  for attr_name, attr_value in bookkeeping:
    setattr(self, attr_name, attr_value)
96
97
def _new__setattr__(self, key, value):
  """`__setattr__` installed on wrapper classes built by `_get_wrapper`.

  The wrapper's two bookkeeping attributes are stored on the wrapper itself
  and do not count as a use. Any other assignment marks the value as used
  (matching `_new__getattribute__` for reads) and is forwarded to the wrapped
  value.

  Args:
    self: The wrapper instance.
    key: Attribute name being assigned.
    value: Value being assigned.
  """
  if key in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'):
    return object.__setattr__(self, key, value)
  # Modifying the output is a use of it, as documented by should_use_result.
  object.__getattribute__(self, '_tf_should_use_helper').sate()
  return setattr(
      object.__getattribute__(self, '_tf_should_use_wrapped_value'),
      key, value)
104
105
def _new__getattribute__(self, key):
  """`__getattribute__` installed on wrapper classes built by `_get_wrapper`.

  Reading any attribute other than the two bookkeeping attributes
  (`_tf_should_use_helper`, `_tf_should_use_wrapped_value`) sates the helper.
  `_tf_should_use_helper`, `mark_used` and `__setattr__` resolve on the
  wrapper itself; every other name is looked up on the wrapped value.

  Args:
    self: The wrapper instance.
    key: Name of the attribute being read.

  Returns:
    The resolved attribute.
  """
  if key not in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'):
    object.__getattribute__(self, '_tf_should_use_helper').sate()
  # `__setattr__` must resolve on the wrapper so that an explicit
  # `wrapper.__setattr__(k, v)` goes through the wrapper's own __setattr__,
  # which keeps the bookkeeping attributes on the wrapper.
  if key in ('_tf_should_use_helper', 'mark_used', '__setattr__'):
    return object.__getattribute__(self, key)
  return getattr(
      object.__getattribute__(self, '_tf_should_use_wrapped_value'), key)
113
114
def _new_mark_used(self, *args, **kwargs):
  """`mark_used` installed on wrapper classes built by `_get_wrapper`.

  Sates this wrapper's helper, then forwards to the wrapped value's own
  `mark_used` if it has one (e.g. when the wrapped value is itself such a
  wrapper).

  Args:
    self: The wrapper instance.
    *args: Forwarded to the wrapped value's `mark_used`.
    **kwargs: Forwarded to the wrapped value's `mark_used`.

  Returns:
    The result of the wrapped value's `mark_used(...)`, or `None` if the
    wrapped value has no `mark_used` attribute.
  """
  object.__getattribute__(self, '_tf_should_use_helper').sate()
  try:
    mu = object.__getattribute__(
        object.__getattribute__(self, '_tf_should_use_wrapped_value'),
        'mark_used')
  except AttributeError:
    # Most wrapped values have no `mark_used`; that is not an error.
    return None
  # Called outside the `try` so an AttributeError raised *inside* the wrapped
  # `mark_used` propagates instead of being silently swallowed.
  return mu(*args, **kwargs)
124
125
# Maps a wrapped type to its memoized usage-tracking wrapper class.
_WRAPPERS = {}


def _get_wrapper(x, tf_should_use_helper):
  """Create a usage-tracking wrapper around the instance `x`.

  The wrapper class is built from the same bases and class namespace as
  `type(x)`. It is a copy of `type(x)`, not a subclass of it, and it is
  memoized per `type(x)` in `_WRAPPERS`. Its `__init__`, `__getattribute__`,
  `__setattr__` and `mark_used` are replaced so that attribute access on the
  wrapper is forwarded to `x` and sates `tf_should_use_helper`.

  Args:
    x: The instance to wrap.
    tf_should_use_helper: The object that tracks usage.

  Returns:
    An instance of the memoized wrapper class for `type(x)`, wrapping `x`.
  """
  type_x = type(x)
  memoized = _WRAPPERS.get(type_x)
  # Compare against None: a class whose metaclass defines `__bool__` may be
  # falsy, and a truthiness test would then rebuild the class on every call.
  if memoized is not None:
    return memoized(x, tf_should_use_helper)

  # Prefer __orig_bases__, which preserves generic type arguments, but only
  # when `type_x` itself defines it. Plain attribute lookup would also find an
  # inherited __orig_bases__ (e.g. `class B(A)` with `class A(Generic[T])`),
  # building the copy of B on A's bases and silently dropping A.
  bases = type_x.__dict__.get('__orig_bases__', type_x.__bases__)

  def set_body(ns):
    ns.update(type_x.__dict__)
    return ns

  # types.new_class resolves generic aliases in `bases` via __mro_entries__;
  # a plain `type(...)` call does not support MRO entry resolution.
  copy_tx = types.new_class(type_x.__name__, bases, exec_body=set_body)

  copy_tx.__init__ = _new__init__
  copy_tx.__getattribute__ = _new__getattribute__
  copy_tx.mark_used = _new_mark_used
  copy_tx.__setattr__ = _new__setattr__
  _WRAPPERS[type_x] = copy_tx

  return copy_tx(x, tf_should_use_helper)
169
170
def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False):
  """Wraps object x so that if it is never used, it is reported.

  Args:
    x: Python object.
    error_in_function: Python bool.  If `True`, a `RuntimeError` is raised
      if the returned value is never used when created during `tf.function`
      tracing.
    warn_in_eager: Python bool. If `True` raise warning if in Eager mode as well
      as graph mode.

  Returns:
    `x` unchanged when it is `None` or an empty list, when executing eagerly
    without `warn_in_eager`, or when inside a function without
    `error_in_function`. Otherwise, a shallow wrapper around `x` produced by
    `_get_wrapper`, whose usage is tracked by a `_TFShouldUseHelper`.
  """
  # Nothing to track for None or an empty list.
  if x is None or (isinstance(x, list) and not x):
    return x

  # Eager mode only tracks usage on request. Inside a tf.function, tracking
  # happens only when an error is requested (no warnings are logged there).
  untracked_here = (
      (context.executing_eagerly() and not warn_in_eager) or
      (ops.inside_function() and not error_in_function))
  if untracked_here:
    return x

  # Remember the caller's frame so an unused value can be reported together
  # with the stack that created it.
  caller_frame = sys._getframe(1)  # pylint: disable=protected-access

  helper = _TFShouldUseHelper(
      type_=type(x),
      repr_=repr(x),
      stack_frame=caller_frame,
      error_in_function=error_in_function,
      warn_in_eager=warn_in_eager)
  return _get_wrapper(x, helper)
210
211
def should_use_result(fn=None, warn_in_eager=False, error_in_function=False):
  """Function wrapper that ensures the function's output is used.

  If the output is not used, a `logging.error` is logged.  If
  `error_in_function` is set, then a `RuntimeError` will be raised at the
  end of function tracing if the output is not used by that point.

  An output is marked as used if any of its attributes are read, modified, or
  updated.  Examples when the output is a `Tensor` include:

  - Using it in any capacity (e.g. `y = t + 0`, `sess.run(t)`)
  - Accessing a property (e.g. getting `t.name` or `t.op`).
  - Calling `t.mark_used()`.

  Note, certain behaviors cannot be tracked - for these the object may not
  be marked as used.  Examples include:

  - `t != 0`.  In this case, comparison is done on types / ids.
  - `isinstance(t, tf.Tensor)`.  Similar to above.

  Can be applied directly (`@should_use_result`) or with keyword arguments
  (`@should_use_result(warn_in_eager=True)`).

  Args:
    fn: The function to wrap, or `None` to get back a decorator.
    warn_in_eager: Whether to create warnings in Eager as well.
    error_in_function: Whether to raise an error when creating a tf.function.

  Returns:
    The wrapped function when `fn` is given; otherwise a decorator.
  """
  note = ('\n\nNote: The output of this function should be used. If it is '
          'not, a warning will be logged or an error may be raised. '
          'To mark the output as used, call its .mark_used() method.')

  def decorated(fn):
    """Decorates the input function."""

    def wrapped(*args, **kwargs):
      result = fn(*args, **kwargs)
      return _add_should_use_warning(
          result,
          warn_in_eager=warn_in_eager,
          error_in_function=error_in_function)

    brief, newline, rest = (fn.__doc__ or '').partition('\n')
    if newline:
      # Keep the summary line and dedent the body, so it is not indented
      # relative to the unindented note appended below.
      updated_doc = brief + newline + textwrap.dedent(rest)
    else:
      updated_doc = brief

    return tf_decorator.make_decorator(
        target=fn,
        decorator_func=wrapped,
        decorator_name='should_use_result',
        decorator_doc=updated_doc + note)

  return decorated if fn is None else decorated(fn)
267