xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/error_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code transformation exceptions."""
16
17import collections
18
19from tensorflow.python.autograph.pyct import origin_info
20from tensorflow.python.util import traceback_utils
21
22
class FrameInfo(
    collections.namedtuple(
        'FrameInfo',
        'filename lineno function_name code is_converted is_allowlisted')):
  """A single entry of a translated stack trace.

  Fields:
    filename: path of the file the frame points to.
    lineno: line number within `filename`.
    function_name: name of the function executing in this frame.
    code: the source line for this frame, or None when it is unavailable.
    is_converted: True if the frame was mapped back to original user code
      through a source map.
    is_allowlisted: True if the frame was called through the converter module
      rather than being converted itself.
  """

  # No per-instance __dict__; instances stay as light as plain tuples.
  __slots__ = ()
29
30
def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
  """Summarizes inner traceback frames up to the call to a given function.

  This function walks `tb` from the innermost (i.e. most recent) frame outwards
  and stops at the first frame whose location can be mapped through
  `source_map` back to the code it originated from. It returns a translated
  stack trace ending at that frame. If no such frame is found, the entire
  stack trace is summarized.

  For example, the following code:

    def f():
      for i in tf.range(1):
        z = y + i  # z only defined here

  Would generate this traceback:

    <converted code>
        ag__.for_stmt(...)
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Which is then processed into:

    <f>
        for i in tf.range(1):
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Args:
    tb: Sequence of `traceback.FrameSummary` objects (or 4-tuples of filename,
      line number, function name and source line) in traceback order, i.e.
      most recent call last. Typically, the output of
      `traceback.StackSummary.extract(..., capture_locals=True)`.
    source_map: Dict[LineLocation, OriginInfo], a source map as created by
      origin_info.create_source_map.
    converter_filename: str, a file path whose call frames are elided from the
      result; the frame recorded just before such a frame (i.e. the frame it
      called into) is marked as allowlisted. Frames enclosing converted code
      are not handled here; they are dropped by stopping at the first mapped
      frame.

  Returns:
    Tuple[FrameInfo, ...], ordered most recent frame first. If a mapped frame
    was found, it is the last element and has `is_converted=True`.
  """
  result_frames = []
  for filename, line_number, function_name, text in reversed(tb):

    loc = origin_info.LineLocation(filename=filename, lineno=line_number)
    if loc in source_map:
      # Innermost frame that maps back to the original code: report it at its
      # original location and stop, dropping every enclosing frame.
      origin = source_map[loc]
      result_frames.append(
          FrameInfo(
              filename=origin.loc.filename,
              lineno=origin.loc.lineno,
              function_name=origin.function_name,
              code=origin.source_code_line,
              is_converted=True,
              is_allowlisted=False))
      break

    if filename == converter_filename:
      # Elide this frame; the frame it called into (the one recorded last) is
      # flagged as allowlisted instead.
      if result_frames:
        prev = result_frames[-1]
        assert not prev.is_converted  # Converted frames end the loop above.
        result_frames[-1] = prev._replace(is_allowlisted=True)
      continue

    result_frames.append(
        FrameInfo(
            filename=filename,
            lineno=line_number,
            function_name=function_name,
            code=text,
            is_converted=False,
            is_allowlisted=False))

  # A tuple (not a list) because callers concatenate translated stacks.
  return tuple(result_frames)
120
121
# Built-in exception types that ErrorMetadataBase.create_exception re-creates
# by passing a single message string to the type's constructor. Membership is
# tested by exact type, so subclasses of these are not matched by this tuple.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
    AssertionError, AttributeError, NameError, NotImplementedError,
    RuntimeError, StopIteration, TypeError, UnboundLocalError, ValueError,
)
133
134
# When a KeyError has a single argument, its str() is that argument's repr(),
# so a multi-line message would be shown quoted with its newlines escaped. We
# create a special subclass that prints the message verbatim. Its __name__ is
# overridden to "KeyError" for display purposes; hopefully that won't create
# too many surprises.
class MultilineMessageKeyError(KeyError):
  """KeyError whose str() is a preformatted, possibly multi-line message.

  Args:
    message: str, the text returned by str().
    original_key: the value stored as the exception's only argument (`args`).
  """

  def __init__(self, message, original_key):
    super(MultilineMessageKeyError, self).__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message

  def __reduce__(self):
    # BaseException.__reduce__ rebuilds the exception as cls(*self.args),
    # which omits `message` and raises TypeError, breaking copy.copy() and
    # pickling. Rebuild from both constructor arguments and restore the
    # instance dict (which also carries any attributes attached later).
    return (type(self), (self.__message,) + self.args, self.__dict__)


MultilineMessageKeyError.__name__ = KeyError.__name__
148
149
class ErrorMetadataBase(object):
  """Container objects attached to exceptions raised in user code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.

  Attributes:
    translated_stack: Tuple[FrameInfo, ...], the translated stack trace,
      ordered most recent frame first.
    cause_message: str, the message of the original error; `get_message`
      appends it, indented, after the stack trace.
  """

  __slots__ = ('translated_stack', 'cause_message')

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
               converter_filename):
    """Builds the metadata from the traceback of the failing call site.

    Args:
      callsite_tb: sequence of `traceback.FrameSummary`, most recent call
        last; translated with `source_map` and `converter_filename`.
      cause_metadata: None, or metadata from an earlier error (an object with
        `translated_stack` and `cause_message`). When given, its stack is
        extended and its message is reused.
      cause_message: str, the original error message; ignored when
        `cause_metadata` is given.
      source_map: source map used to translate `callsite_tb`.
      converter_filename: str, file whose frames are elided from the
        translated stack.
    """
    translated_stack = _stack_trace_inside_mapped_code(
        callsite_tb, source_map, converter_filename)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: keep the cause's frames and append
      # only the outermost frame of this call site. Slicing (rather than
      # indexing) tolerates an empty translated stack instead of raising
      # IndexError while the original error is being reported.
      self.translated_stack = (
          cause_metadata.translated_stack + translated_stack[-1:])
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception.

    The message lists the translated stack in Python traceback order (most
    recent call last), marking converted frames with `*` and allowlisted frames
    with `**`, followed by the indented cause message. When traceback filtering
    is enabled, frames rejected by `traceback_utils.include_frame` are omitted.
    """
    lines = ['in user code:', '']

    for frame_info in reversed(self.translated_stack):
      if (traceback_utils.is_traceback_filtering_enabled() and
          not traceback_utils.include_frame(frame_info.filename)):
        continue

      # Same format with Python traceback.
      formatted_line = (f'    File "{frame_info.filename}", line '
                        f'{frame_info.lineno}, in {frame_info.function_name}')
      if frame_info.is_converted:
        formatted_line += '  *'
      elif frame_info.is_allowlisted:
        formatted_line += '  **'
      lines.append(formatted_line)

      if frame_info.code is None:
        code_snippet = '<source unavailable>'
      else:
        code_snippet = frame_info.code.strip()
      lines.append(f'        {code_snippet}')

    lines.append('')
    lines.extend(
        f'    {message_line}'
        for message_line in self.cause_message.split('\n'))
    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates exception from source_error.

    Args:
      source_error: the exception raised by the generated code.

    Returns:
      A new exception of the same type as `source_error` (or, for KeyError, a
      `MultilineMessageKeyError`) whose message is `get_message()` and whose
      traceback is `source_error`'s. None if the type of `source_error` does
      not inherit Exception's constructor, is not one of
      `KNOWN_STRING_CONSTRUCTOR_ERRORS`, and is not exactly KeyError.
    """
    preferred_type = type(source_error)
    if (preferred_type.__init__ is Exception.__init__ or
        preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS):
      # These types can be rebuilt from a single message string.
      to_ret = preferred_type(self.get_message())
    elif preferred_type is KeyError:
      # KeyError's str() would escape the multi-line message; use a subclass
      # that prints it verbatim.
      to_ret = MultilineMessageKeyError(self.get_message(), self.cause_message)
    else:
      return None

    return to_ret.with_traceback(source_error.__traceback__)

  def to_exception(self, source_error):
    """Returns the exception to raise in place of `source_error`.

    The exception built by `create_exception` gets `__suppress_context__` set
    (so the implicit exception context is not displayed) and this metadata
    attached as `ag_error_metadata`.
    """
    exc = self.create_exception(source_error)
    # NOTE(review): `create_exception` returns None for unsupported types, in
    # which case the assignments below raise AttributeError; presumably
    # subclasses override `create_exception` so it always returns an
    # exception - confirm.
    exc.__suppress_context__ = True
    exc.ag_error_metadata = self
    return exc
231