1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Function for interpolating formatted errors from the TensorFlow runtime. 16 17Exposes the function `interpolate` to interpolate messages with tags of the form 18{{type name}}. 19""" 20 21import collections 22import os 23import re 24import site 25import traceback 26 27from tensorflow.core.protobuf import graph_debug_info_pb2 28 29_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?" 30_TAG_REGEX = fr"{{{{(?P<type>{_NAME_REGEX}) (?P<name>{_NAME_REGEX})}}}}" 31_INTERPOLATION_REGEX = fr"(?P<sep>.*?)(?P<tag>{_TAG_REGEX})" 32_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL) 33 34_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"]) 35 36 37# Remove the last three path components from this module's file (i.e. 38# python/framework/error_interpolation.py) so that we have an absolute path 39# prefix to the root of the installation. 40_FRAMEWORK_COMMON_PREFIX = os.path.dirname( 41 os.path.dirname(os.path.dirname(__file__))) 42 43# Sub-directories under the common prefix that are considered part of the 44# framework. 45# Note that keras code lives outside of tensorflow directory, we need to walk 46# up the directory tree and find it. 47_FRAMEWORK_PATH_PREFIXES = [ 48 os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep, 49 os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep, 50 os.path.join(os.path.dirname(_FRAMEWORK_COMMON_PREFIX), 51 "py", "keras") + os.sep, 52] 53 54# Patterns of filename patterns that should be considered internal to 55# the TensorFlow framework. 56_FRAMEWORK_FILENAME_PATTERNS = [ 57 re.compile(r"<embedded"), 58] 59 60# This is for OSS keras, since the package is load from local python env, 61# but we don't know exactly where it is installed. Matching to keyword 62# "keras". 63try: 64 _FRAMEWORK_PATH_PREFIXES.extend([ 65 os.path.join(package_path, "keras") + os.sep 66 for package_path in site.getsitepackages() + [site.getusersitepackages()] 67 ]) 68except AttributeError: 69 # if site.getsitepackages is not available somehow, we just use the "keras" as 70 # the keyword to do the match. 71 _FRAMEWORK_FILENAME_PATTERNS.append(re.compile(r"keras")) 72 73# Patterns of filename patterns that should be considered external to 74# TensorFlow regardless of framework prefix match. 75_EXTERNAL_FILENAME_PATTERNS = [ 76 # Explicitly treat test frames as not part of the framework. 77 re.compile(r"_test\.py$"), 78] 79 80 81def parse_message(message): 82 """Extract function tags and node tags from a message. 83 84 Tags are named tuples representing the string {{type name}}. For example, 85 in "123{{node Foo}}456{{function_node Bar}}789", there are two tags: a node 86 tag and a function tag. 87 88 Args: 89 message: An error message, possibly from an OpError. 90 91 Returns: 92 A tuple containing the original message with function nodes stripped, 93 function tags, and node tags. 94 95 For example, if message is "123{{node Foo}}456{{function_node Bar}}789" 96 then this function returns ("123{{node Foo}}456789", 97 [_ParseTag("function_node", "Bar")], [_ParseTag("node", "Foo")]). 98 """ 99 error_message = [] 100 func_tags = [] 101 node_tags = [] 102 pos = 0 103 for match in re.finditer(_INTERPOLATION_PATTERN, message): 104 parsed_tag = _ParseTag(match.group("type"), match.group("name")) 105 if parsed_tag.type == "function_node": 106 error_message.append(match.group("sep")) 107 func_tags.append(parsed_tag) 108 else: 109 error_message.append(match.group()) 110 node_tags.append(parsed_tag) 111 pos = match.end() 112 error_message.append(message[pos:]) 113 return "".join(error_message), func_tags, node_tags 114 115 116def _compute_device_summary_from_list(name, device_assignment_list, prefix=""): 117 """Return a summary of an op's device function stack. 118 119 Args: 120 name: The name of the op. 121 device_assignment_list: The op._device_assignments list. 122 prefix: An optional string prefix used before each line of the multi- 123 line string returned by this function. 124 125 Returns: 126 A multi-line string similar to: 127 Device assignments active during op 'foo' creation: 128 with tf.device(/cpu:0): <test_1.py:27> 129 with tf.device(some_func<foo.py, 123>): <test_2.py:38> 130 The first line will have no padding to its left by default. Subsequent 131 lines will have two spaces of left-padding. Use the prefix argument 132 to increase indentation. 133 """ 134 if not device_assignment_list: 135 message = "No device assignments were active during op '%s' creation." 136 message %= name 137 return prefix + message 138 139 str_list = [] 140 str_list.append( 141 "%sDevice assignments active during op '%s' creation:" % (prefix, name)) 142 143 for traceable_obj in device_assignment_list: 144 location_summary = "<{file}:{line}>".format( 145 file=traceable_obj.filename, line=traceable_obj.lineno) 146 subs = { 147 "prefix": prefix, 148 "indent": " ", 149 "dev_name": traceable_obj.obj, 150 "loc": location_summary, 151 } 152 str_list.append( 153 "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs)) 154 155 return "\n".join(str_list) 156 157 158def _compute_device_assignment_summary_from_op(op, prefix=""): 159 # pylint: disable=protected-access 160 return _compute_device_summary_from_list(op.name, op._device_assignments, 161 prefix) 162 # pylint: enable=protected-access 163 164 165def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""): 166 """Return a summary of an op's colocation stack. 167 168 Args: 169 name: The op name. 170 colocation_dict: The op._colocation_dict. 171 prefix: An optional string prefix used before each line of the multi- 172 line string returned by this function. 173 174 Returns: 175 A multi-line string similar to: 176 Node-device colocations active during op creation: 177 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 178 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38> 179 The first line will have no padding to its left by default. Subsequent 180 lines will have two spaces of left-padding. Use the prefix argument 181 to increase indentation. 182 """ 183 if not colocation_dict: 184 message = "No node-device colocations were active during op '%s' creation." 185 message %= name 186 return prefix + message 187 188 str_list = [] 189 str_list.append("%sNode-device colocations active during op '%s' creation:" % 190 (prefix, name)) 191 192 for coloc_name, location in colocation_dict.items(): 193 location_summary = "<{file}:{line}>".format( 194 file=location.filename, line=location.lineno) 195 subs = { 196 "prefix": prefix, 197 "indent": " ", 198 "name": coloc_name, 199 "loc": location_summary, 200 } 201 str_list.append( 202 "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs)) 203 204 return "\n".join(str_list) 205 206 207def _compute_colocation_summary_from_op(op, prefix=""): 208 """Fetch colocation file, line, and nesting and return a summary string.""" 209 # pylint: disable=protected-access 210 return _compute_colocation_summary_from_dict(op.name, op._colocation_dict, 211 prefix) 212 # pylint: enable=protected-access 213 214 215def _is_framework_filename(filename): 216 """Returns whether a filename should be considered a part of the framework. 217 218 A file is part of the framework if it does not match a pattern in 219 _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in 220 _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix. 221 222 Args: 223 filename: A filename string. 224 225 Returns: 226 Whether the filename should be considered to be internal to the 227 TensorFlow framework for the purposes of reporting errors. 228 """ 229 for pattern in _EXTERNAL_FILENAME_PATTERNS: 230 if pattern.search(filename): 231 return False 232 for pattern in _FRAMEWORK_FILENAME_PATTERNS: 233 if pattern.search(filename): 234 return True 235 for prefix in _FRAMEWORK_PATH_PREFIXES: 236 if filename.startswith(prefix): 237 return True 238 return False 239 240 241def _find_index_of_defining_frame(tb): 242 """Return index in op.traceback with first 'useful' frame. 243 244 This method reads through the stack stored in op.traceback looking for the 245 innermost frame which (hopefully) belongs to the caller. It accomplishes this 246 by rejecting frames deemed to be part of the TensorFlow framework (by 247 pattern matching the filename). 248 249 Args: 250 tb: A list of traceback frames (as from Operation.traceback). 251 252 Returns: 253 Integer index into op.traceback where the first non-TF file was found 254 (innermost to outermost), or 0 (for the outermost stack frame) if all files 255 came from TensorFlow. 256 """ 257 # Index 0 of traceback is the outermost frame. 258 size = len(tb) 259 filenames = [frame.filename for frame in tb] 260 # We process the filenames from the innermost frame to outermost. 261 for idx, filename in enumerate(reversed(filenames)): 262 is_framework = _is_framework_filename(filename) 263 if not is_framework: 264 # Consider this to be the defining frame. 265 return size - idx - 1 266 return 0 267 268 269# TODO(feyu): follow up with users of this function (saved model) 270# to see what 'useful' means and whether we can obliviate this. 271def _compute_useful_frames(tb, num): 272 """Return a list of frames, which form a 'useful' stack. 273 274 Starting from the defining frame to the outermost one, this method computes 275 the contiguous portion of the 'useful' stack trace and returns the selected 276 frames. 277 278 Args: 279 tb: A list of traceback frames (as from Operation.traceback). 280 num: total number of frames to return. 281 282 Returns: 283 A list of frames. 284 """ 285 defining_frame_index = _find_index_of_defining_frame(tb) 286 # The stack trace is collected from two lines before the defining frame in the 287 # model file to the outermost with `num` frames at most. These two extra lines 288 # are included from the TensorFlow library to give the context which node is 289 # defined. 290 innermost_excluded = min(defining_frame_index + 2 + 1, len(tb)) 291 outermost_included = max(innermost_excluded - num, 0) 292 return tb[outermost_included:innermost_excluded] 293 294 295def create_graph_debug_info_def(func_named_operations): 296 """Construct and returns a `GraphDebugInfo` protocol buffer. 297 298 Args: 299 func_named_operations: An iterable of (func_name, op.Operation) tuples 300 where the Operation instances have a _traceback members. The func_name 301 should be the empty string for operations in the top-level Graph. 302 303 Returns: 304 GraphDebugInfo protocol buffer. 305 306 Raises: 307 TypeError: If the arguments are not of the correct proto buffer type. 308 """ 309 # Creates an empty GraphDebugInfoDef proto. 310 graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo() 311 312 # Gets the file names and line numbers for the exported node names. Also 313 # collects the unique file names. 314 all_file_names = set() 315 node_to_trace = {} 316 for func_name, op in func_named_operations: 317 if op.traceback is None: 318 continue 319 # Gets the stack trace of the operation and then the file location. 320 node_name = op.name + "@" + func_name 321 node_to_trace[node_name] = _compute_useful_frames(op.traceback, 10) 322 for frame in node_to_trace[node_name]: 323 all_file_names.add(frame.filename) 324 325 # Sets the `files` field in the GraphDebugInfo proto 326 graph_debug_info_def.files.extend(all_file_names) 327 328 # Builds a mapping between file names and index of the `files` field, so we 329 # only store the indexes for the nodes in the GraphDebugInfo. 330 file_to_index = dict( 331 [(y, x) for x, y in enumerate(graph_debug_info_def.files)]) 332 333 # Creates the FileLineCol proto for each node and sets the value in the 334 # GraphDebugInfo proto. We only store the file name index for each node to 335 # save the storage space. 336 for node_name, frames in node_to_trace.items(): 337 trace_def = graph_debug_info_def.traces[node_name] 338 for frame in reversed(frames): 339 trace_def.file_line_cols.add( 340 file_index=file_to_index[frame.filename], 341 line=frame.lineno) 342 343 return graph_debug_info_def 344 345 346def _compute_field_dict(op): 347 r"""Return a dictionary mapping interpolation tokens to values. 348 349 Args: 350 op: op.Operation object. 351 352 Returns: 353 A dictionary mapping string tokens to string values. The keys are shown 354 below along with example values. 355 { 356 "file": "tool_utils.py", 357 "lineno": "124", 358 "line": " source code line", 359 "defined_at": " (defined at tool_utils.py:124)", 360 "colocations": 361 '''Node-device colocations active during op creation: 362 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 363 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''' 364 "devices": 365 '''Device assignments active during op 'foo' creation: 366 with tf.device(/cpu:0): <test_1.py:27> 367 with tf.device(some_func<foo.py, 123>): <test_2.py:38>''' 368 "devs_and_colocs": A concatenation of colocations and devices, e.g. 369 '''Node-device colocations active during op creation: 370 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 371 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''' 372 Device assignments active during op 'foo' creation: 373 with tf.device(/cpu:0): <test_1.py:27> 374 with tf.device(some_func<foo.py, 123>): <test_2.py:38>''' 375 } 376 """ 377 # TODO(xjun): colocation and device info are not displayed. Consider 378 # removing them or using vlog. 379 colocation_summary = _compute_colocation_summary_from_op(op) 380 device_summary = _compute_device_assignment_summary_from_op(op) 381 combined_summary = "\n".join([colocation_summary, device_summary]) 382 383 if op.traceback is None: 384 # Some ops synthesized on as part of function or control flow definition 385 # do not have tracebacks. 386 filename = "<unknown>" 387 definition_traceback = "" 388 lineno = 0 389 line = "" 390 defined_at = "<unknown>" 391 else: 392 frame = op.traceback.last_user_frame() 393 filename = frame.filename 394 definition_traceback = traceback.format_list(op.traceback.get_user_frames()) 395 lineno = frame.lineno 396 line = frame.line 397 defined_at = f"{filename}:{lineno:d}" 398 399 field_dict = { 400 "colocations": colocation_summary, 401 "devices": device_summary, 402 "devs_and_colocs": combined_summary, 403 "defined_at": defined_at, 404 "file": filename, 405 "lineno": lineno, 406 "line": line, 407 "definition_traceback": definition_traceback, 408 } 409 return field_dict 410 411 412def _build_node_error_message(op): 413 """Returns the formatted error message for the given op. 414 415 Args: 416 op: The node. 417 418 Returns: 419 The formatted error message for the given op with traceback. 420 """ 421 node_error_message = [ 422 f"Detected at node {op.name!r} defined at (most recent call last):" 423 ] 424 field_dict = _compute_field_dict(op) 425 426 # Add node traceback. 427 for frame in field_dict["definition_traceback"]: 428 if "<embedded" not in frame: 429 node_error_message.extend( 430 [f" {line}" for line in frame.split("\n") if line.strip()]) 431 432 # Add node name. 433 node_error_message.append(f"Node: {op.name!r}") 434 435 return "\n".join(node_error_message) 436 437 438def interpolate(message, graph): 439 """Interpolates an error message. 440 441 The error message can contain tags of form `{{node_type node_name}}` 442 which will be parsed to identify the tf.Graph and op. If the op contains 443 traceback, the traceback will be attached to the error message. 444 445 Args: 446 message: A string to interpolate. 447 graph: ops.Graph object containing all nodes referenced in the error 448 message. 449 450 Returns: 451 The error message string with node definition traceback. 452 """ 453 parsed_messaged, _, node_tags = parse_message(message) 454 error_message = ["Graph execution error:", ""] 455 for tag in node_tags: 456 try: 457 op = graph.get_operation_by_name(tag.name) 458 except KeyError: 459 continue 460 else: 461 error_message.append(_build_node_error_message(op)) 462 463 error_message.append(parsed_messaged.strip()) 464 return "\n".join(error_message) 465