# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code transformation exceptions."""

import collections

from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.util import traceback_utils


class FrameInfo(
    collections.namedtuple('FrameInfo',
                           ('filename', 'lineno', 'function_name', 'code',
                            'is_converted', 'is_allowlisted'))):
  """One frame of a translated, user-readable stack trace.

  Fields:
    filename: str, path of the source file the frame points to.
    lineno: int, line number within `filename`.
    function_name: str, name of the function executing in this frame.
    code: Optional[str], the source line of the frame; may be None, in which
      case it is rendered as '<source unavailable>'.
    is_converted: bool, True if the frame was mapped back from generated code
      to the user code it originated from.
    is_allowlisted: bool, True if the frame was called directly from the
      converter module (whose own call frame is elided).
  """

  __slots__ = ()


def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
  """Summarizes inner traceback frames up to the innermost mapped frame.

  This function walks `tb` from the innermost (i.e. most recent) frame outward
  and stops at the first frame whose location appears in `source_map`, i.e.
  the innermost frame of generated code that can be mapped back to the user
  code it originated from. That frame is translated to its original location
  and becomes the last element of the result; frames enclosing it are dropped.
  If no such frame is found, the entire stack trace is summarized.

  For example, the following code:

    def f():
      for i in tf.range(1):
        z = y + i  # z only defined here

  Would generate this traceback:

    <converted code>
        ag__.for_stmt(...)
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Which is then processed into:

    <f>
        for i in tf.range(1):
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Args:
    tb: traceback.StackSummary (or any reversible sequence of frames that
      unpack as (filename, lineno, function_name, line)), in the usual
      traceback order (outermost first), since it is walked in reverse.
      Typically, the output of
      traceback.StackSummary.extract(..., capture_locals=True).
    source_map: Dict[LineLocation, OriginInfo], a source map as created by
      origin_info.create_source_map.
    converter_filename: str, the file path of the converted module. Call frames
      corresponding to this module are elided and the frames they called are
      marked as allowlisted. Note that frames enclosing converted code are
      dropped using a different mechanism.

  Returns:
    Tuple[FrameInfo, ...], innermost (most recent) frame first. If a mapped
    frame was found, it is the last element.
  """
  result_frames = []
  for filename, line_number, function_name, text in reversed(tb):

    loc = origin_info.LineLocation(filename=filename, lineno=line_number)
    if loc in source_map:
      # Innermost frame of generated code: translate it back to the user's
      # source and stop, dropping every frame further out.
      origin = source_map[loc]
      result_frames.append(
          FrameInfo(
              filename=origin.loc.filename,
              lineno=origin.loc.lineno,
              function_name=origin.function_name,
              code=origin.source_code_line,
              is_converted=True,
              is_allowlisted=False))
      break

    if filename == converter_filename:
      # Elide the converter's own call frame, and flag the most recently
      # recorded frame (the one just inside it) as allowlisted.
      if result_frames:
        prev = result_frames[-1]
        assert not prev.is_converted  # Converted frames end the loop above.
        result_frames[-1] = prev._replace(
            is_converted=False, is_allowlisted=True)
      continue

    result_frames.append(
        FrameInfo(
            filename=filename,
            lineno=line_number,
            function_name=function_name,
            code=text,
            is_converted=False,
            is_allowlisted=False))

  return tuple(result_frames)


# Exception types that create_exception re-creates by calling the type with a
# single message string.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
    AssertionError,
    AttributeError,
    NameError,
    NotImplementedError,
    RuntimeError,
    StopIteration,
    TypeError,
    UnboundLocalError,
    ValueError,
)


# KeyError escapes newlines in strings. We create a special subclass
# that doesn't do that. Overriding the name for display purposes; hopefully
# that won't create too many surprises.
class MultilineMessageKeyError(KeyError):
  """A KeyError whose str() is a verbatim (multi-line) message.

  The exception's args still hold `original_key`, so code that inspects the
  key keeps working; only the string form is replaced.
  """

  def __init__(self, message, original_key):
    super().__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message


MultilineMessageKeyError.__name__ = KeyError.__name__


class ErrorMetadataBase(object):
  """Container objects attached to exceptions raised in user code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.
  """

  __slots__ = ('translated_stack', 'cause_message')

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
               converter_filename):
    """Translates `callsite_tb` and records it with the cause message.

    Args:
      callsite_tb: traceback frames of the error, as accepted by
        _stack_trace_inside_mapped_code.
      cause_metadata: Optional[ErrorMetadataBase], metadata of an earlier
        error this one chains onto, or None.
      cause_message: str, the message of the underlying error. Ignored when
        `cause_metadata` is given (its message is reused instead).
      source_map: Dict[LineLocation, OriginInfo], see
        _stack_trace_inside_mapped_code.
      converter_filename: str, see _stack_trace_inside_mapped_code.
    """
    translated_stack = _stack_trace_inside_mapped_code(
        callsite_tb, source_map, converter_filename)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: extend the cause's stack with the
      # outermost frame of this one. Slicing (rather than indexing [-1]) keeps
      # this from raising IndexError when the translated stack is empty.
      self.translated_stack = (
          cause_metadata.translated_stack + translated_stack[-1:])
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception."""
    lines = []

    lines.append('in user code:')
    lines.append('')

    # translated_stack is innermost-first; print outermost-first like Python.
    for frame_info in reversed(self.translated_stack):
      if (traceback_utils.is_traceback_filtering_enabled() and
          not traceback_utils.include_frame(frame_info.filename)):
        continue

      # Same format with Python traceback.
      formatted_line = (f' File "{frame_info.filename}", line '
                        f'{frame_info.lineno}, in {frame_info.function_name}')
      if frame_info.is_converted:
        formatted_line += ' *'
      elif frame_info.is_allowlisted:
        formatted_line += ' **'
      lines.append(formatted_line)

      if frame_info.code is None:
        code_snippet = '<source unavailable>'
      else:
        code_snippet = frame_info.code.strip()
      lines.append(' {}'.format(code_snippet))

    lines.append('')

    # Indent every line of the cause message.
    lines.extend(' ' + line for line in self.cause_message.split('\n'))

    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates exception from source_error.

    The new exception has the same type as `source_error` (or a KeyError
    subclass for KeyError), carries the message from get_message(), and
    reuses `source_error`'s traceback.

    Args:
      source_error: the exception raised in generated code.

    Returns:
      The new exception, or None when the type of `source_error` is not known
      to accept a single message argument.
    """
    preferred_type = type(source_error)
    if preferred_type is KeyError:
      # Plain KeyError would escape the newlines in the message.
      to_ret = MultilineMessageKeyError(self.get_message(), self.cause_message)
    elif (preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS or
          preferred_type.__init__ is Exception.__init__):
      # Either a builtin known to take a message string, or a type that does
      # not customize Exception's constructor.
      to_ret = preferred_type(self.get_message())
    else:
      return None
    return to_ret.with_traceback(source_error.__traceback__)

  def to_exception(self, source_error):
    """Returns the exception to raise in place of `source_error`.

    Note: if create_exception returns None for this error type, the attribute
    assignments below fail with AttributeError; subclasses are presumably
    expected to override create_exception to always return an exception.
    """
    exc = self.create_exception(source_error)
    # Hide the implicit exception context when the traceback is printed.
    exc.__suppress_context__ = True
    # NOTE(review): presumably read later to chain metadata via
    # `cause_metadata` — confirm against callers.
    exc.ag_error_metadata = self
    return exc