# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Function for interpolating formatted errors from the TensorFlow runtime.

Exposes the function `interpolate` to interpolate messages with tags of the form
{{type name}}.
"""

import collections
import os
import re
import site
import traceback

from tensorflow.core.protobuf import graph_debug_info_pb2

_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
_TAG_REGEX = fr"{{{{(?P<type>{_NAME_REGEX}) (?P<name>{_NAME_REGEX})}}}}"
_INTERPOLATION_REGEX = fr"(?P<sep>.*?)(?P<tag>{_TAG_REGEX})"
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)

_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])

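# A minimal sketch of what the pattern above matches (illustrative message,
# not executed at import time):
#
#   match = _INTERPOLATION_PATTERN.match("error at {{node foo/bar}}!")
#   match.group("sep")   # -> "error at "
#   match.group("type")  # -> "node"
#   match.group("name")  # -> "foo/bar"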

# Remove the last three path components from this module's file (i.e.
# python/framework/error_interpolation.py) so that we have an absolute path
# prefix to the root of the installation.
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
    os.path.dirname(os.path.dirname(__file__)))

# Sub-directories under the common prefix that are considered part of the
# framework.
# Note that Keras code lives outside the tensorflow directory, so we need to
# walk up the directory tree to find it.
_FRAMEWORK_PATH_PREFIXES = [
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
    os.path.join(os.path.dirname(_FRAMEWORK_COMMON_PREFIX),
                 "py", "keras") + os.sep,
]

# Patterns matching filenames that should be considered internal to
# the TensorFlow framework.
_FRAMEWORK_FILENAME_PATTERNS = [
    re.compile(r"<embedded"),
]

# This is for OSS Keras: the package is loaded from the local Python
# environment, but we don't know exactly where it is installed, so we
# match on the "keras" path component.
try:
  _FRAMEWORK_PATH_PREFIXES.extend([
      os.path.join(package_path, "keras") + os.sep
      for package_path in site.getsitepackages() + [site.getusersitepackages()]
  ])
except AttributeError:
  # If site.getsitepackages is not available for some reason, fall back to
  # matching the keyword "keras" in the filename.
  _FRAMEWORK_FILENAME_PATTERNS.append(re.compile(r"keras"))

# Patterns matching filenames that should be considered external to
# TensorFlow regardless of framework prefix match.
_EXTERNAL_FILENAME_PATTERNS = [
    # Explicitly treat test frames as not part of the framework.
    re.compile(r"_test\.py$"),
]


def parse_message(message):
  """Extract function tags and node tags from a message.

  Tags are named tuples representing the string {{type name}}. For example,
  in "123{{node Foo}}456{{function_node Bar}}789", there are two tags: a node
  tag and a function tag.

  Args:
    message: An error message, possibly from an OpError.

  Returns:
    A tuple containing the original message with function_node tags stripped,
    function tags, and node tags.

    For example, if message is "123{{node Foo}}456{{function_node Bar}}789"
    then this function returns ("123{{node Foo}}456789",
    [_ParseTag("function_node", "Bar")], [_ParseTag("node", "Foo")]).
  """
  error_message = []
  func_tags = []
  node_tags = []
  pos = 0
  for match in re.finditer(_INTERPOLATION_PATTERN, message):
    parsed_tag = _ParseTag(match.group("type"), match.group("name"))
    if parsed_tag.type == "function_node":
      error_message.append(match.group("sep"))
      func_tags.append(parsed_tag)
    else:
      error_message.append(match.group())
      node_tags.append(parsed_tag)
    pos = match.end()
  error_message.append(message[pos:])
  return "".join(error_message), func_tags, node_tags

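# A short usage sketch mirroring the docstring example (illustrative only):
#
#   msg = "123{{node Foo}}456{{function_node Bar}}789"
#   stripped, func_tags, node_tags = parse_message(msg)
#   # stripped  -> "123{{node Foo}}456789"
#   # func_tags -> [_ParseTag(type="function_node", name="Bar")]
#   # node_tags -> [_ParseTag(type="node", name="Foo")]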

def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Return a summary of an op's device function stack.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list.
    prefix:  An optional string prefix used before each line of the multi-
        line string returned by this function.

  Returns:
    A multi-line string similar to:
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not device_assignment_list:
    message = "No device assignments were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append(
      "%sDevice assignments active during op '%s' creation:" % (prefix, name))

  for traceable_obj in device_assignment_list:
    location_summary = "<{file}:{line}>".format(
        file=traceable_obj.filename, line=traceable_obj.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "dev_name": traceable_obj.obj,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))

  return "\n".join(str_list)

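# A minimal sketch of the expected output (illustrative; `Traceable` stands in
# for the TraceableObject entries of op._device_assignments, which carry
# .obj, .filename and .lineno):
#
#   Traceable = collections.namedtuple(
#       "Traceable", ["obj", "filename", "lineno"])
#   _compute_device_summary_from_list(
#       "foo", [Traceable("/cpu:0", "test_1.py", 27)])
#   # -> "Device assignments active during op 'foo' creation:\n"
#   #    "  with tf.device(/cpu:0): <test_1.py:27>"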

def _compute_device_assignment_summary_from_op(op, prefix=""):
  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op.name, op._device_assignments,
                                           prefix)
  # pylint: enable=protected-access


def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Return a summary of an op's colocation stack.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict.
    prefix:  An optional string prefix used before each line of the multi-
        line string returned by this function.

  Returns:
    A multi-line string similar to:
        Node-device colocations active during op creation:
          with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
          with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not colocation_dict:
    message = "No node-device colocations were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append("%sNode-device colocations active during op '%s' creation:" %
                  (prefix, name))

  for coloc_name, location in colocation_dict.items():
    location_summary = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "name": coloc_name,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_colocation_summary_from_op(op, prefix=""):
  """Fetch colocation file, line, and nesting and return a summary string."""
  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
                                               prefix)
  # pylint: enable=protected-access


def _is_framework_filename(filename):
  """Returns whether a filename should be considered a part of the framework.

  A file is part of the framework if it does not match a pattern in
  _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
  _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix.

  Args:
    filename: A filename string.

  Returns:
    Whether the filename should be considered to be internal to the
    TensorFlow framework for the purposes of reporting errors.
  """
  for pattern in _EXTERNAL_FILENAME_PATTERNS:
    if pattern.search(filename):
      return False
  for pattern in _FRAMEWORK_FILENAME_PATTERNS:
    if pattern.search(filename):
      return True
  for prefix in _FRAMEWORK_PATH_PREFIXES:
    if filename.startswith(prefix):
      return True
  return False

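# Illustrative classifications (hypothetical paths):
#
#   _is_framework_filename(_FRAMEWORK_PATH_PREFIXES[0] + "ops.py")       # True
#   _is_framework_filename("/home/user/my_model.py")                     # False
#   _is_framework_filename(_FRAMEWORK_PATH_PREFIXES[0] + "ops_test.py")  # False
#
# The last case is False because the _test.py pattern marks the file as
# external even though it starts with a framework prefix.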

def _find_index_of_defining_frame(tb):
  """Return index in op.traceback with first 'useful' frame.

  This method reads through the stack stored in op.traceback looking for the
  innermost frame which (hopefully) belongs to the caller.  It accomplishes this
  by rejecting frames deemed to be part of the TensorFlow framework (by
  pattern matching the filename).

  Args:
    tb: A list of traceback frames (as from Operation.traceback).

  Returns:
    Integer index into op.traceback where the first non-TF file was found
    (innermost to outermost), or 0 (for the outermost stack frame) if all files
    came from TensorFlow.
  """
  # Index 0 of traceback is the outermost frame.
  size = len(tb)
  filenames = [frame.filename for frame in tb]
  # We process the filenames from the innermost frame to outermost.
  for idx, filename in enumerate(reversed(filenames)):
    is_framework = _is_framework_filename(filename)
    if not is_framework:
      # Consider this to be the defining frame.
      return size - idx - 1
  return 0

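# A minimal sketch (illustrative; `Frame` stands in for traceback frame
# objects, of which only the `filename` attribute is used here):
#
#   Frame = collections.namedtuple("Frame", ["filename"])
#   tb = [Frame("/app/train.py"),  # user code (outermost, index 0)
#         Frame(_FRAMEWORK_PATH_PREFIXES[0] + "ops.py"),
#         Frame(_FRAMEWORK_PATH_PREFIXES[0] + "op_def_library.py")]
#   _find_index_of_defining_frame(tb)
#   # -> 0: index of the innermost non-framework frame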

# TODO(feyu): follow up with users of this function (saved model)
# to see what 'useful' means and whether we can obviate this.
def _compute_useful_frames(tb, num):
  """Return a list of frames, which form a 'useful' stack.

  Starting from the defining frame to the outermost one, this method computes
  the contiguous portion of the 'useful' stack trace and returns the selected
  frames.

  Args:
    tb: A list of traceback frames (as from Operation.traceback).
    num: total number of frames to return.

  Returns:
    A list of frames.
  """
  defining_frame_index = _find_index_of_defining_frame(tb)
  # The stack trace is collected from two frames past the defining frame, back
  # toward the outermost frame, with at most `num` frames in total. The two
  # extra frames come from the TensorFlow library and give context on where
  # the node is defined.
  innermost_excluded = min(defining_frame_index + 2 + 1, len(tb))
  outermost_included = max(innermost_excluded - num, 0)
  return tb[outermost_included:innermost_excluded]

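# A worked instance of the slicing arithmetic above (illustrative numbers):
# with len(tb) == 12, defining_frame_index == 8, and num == 10:
#
#   innermost_excluded = min(8 + 2 + 1, 12) == 11
#   outermost_included = max(11 - 10, 0)    == 1
#
# so the result is tb[1:11], i.e. 10 frames ending two past the defining one.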

def create_graph_debug_info_def(func_named_operations):
  """Constructs and returns a `GraphDebugInfo` protocol buffer.

  Args:
    func_named_operations: An iterable of (func_name, op.Operation) tuples
      where the Operation instances have a `_traceback` member. The func_name
      should be the empty string for operations in the top-level Graph.

  Returns:
    GraphDebugInfo protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct protocol buffer type.
  """
  # Creates an empty GraphDebugInfo proto.
  graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()

  # Gets the file names and line numbers for the exported node names. Also
  # collects the unique file names.
  all_file_names = set()
  node_to_trace = {}
  for func_name, op in func_named_operations:
    if op.traceback is None:
      continue
    # Gets the stack trace of the operation and then the file location.
    node_name = op.name + "@" + func_name
    node_to_trace[node_name] = _compute_useful_frames(op.traceback, 10)
    for frame in node_to_trace[node_name]:
      all_file_names.add(frame.filename)

  # Sets the `files` field in the GraphDebugInfo proto.
  graph_debug_info_def.files.extend(all_file_names)

  # Builds a mapping from file name to its index in the `files` field, so that
  # nodes only need to store the index of each file name.
  file_to_index = dict(
      [(y, x) for x, y in enumerate(graph_debug_info_def.files)])

  # Creates the FileLineCol proto for each node and sets the value in the
  # GraphDebugInfo proto. We only store the file name index for each node to
  # save storage space.
  for node_name, frames in node_to_trace.items():
    trace_def = graph_debug_info_def.traces[node_name]
    for frame in reversed(frames):
      trace_def.file_line_cols.add(
          file_index=file_to_index[frame.filename],
          line=frame.lineno)

  return graph_debug_info_def

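# A minimal usage sketch (illustrative; assumes `g` is a tf.Graph whose ops
# already carry tracebacks):
#
#   ops = [("", op) for op in g.get_operations()]
#   debug_info = create_graph_debug_info_def(ops)
#   # debug_info.files  -> unique source file names seen in the tracebacks
#   # debug_info.traces -> maps "<op_name>@<func_name>" to FileLineCol entries
#   #                      whose file_index values point into debug_info.files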

def _compute_field_dict(op):
  r"""Return a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object.

  Returns:
    A dictionary mapping string tokens to string values.  The keys are shown
    below along with example values.
    {
      "file": "tool_utils.py",
      "lineno": "124",
      "line": "  source code line",
      "defined_at": "tool_utils.py:124",
      "colocations":
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''',
      "devices":
          '''Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>''',
      "devs_and_colocs": A concatenation of colocations and devices, e.g.
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
             Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
    }
  """
  # TODO(xjun): colocation and device info are not displayed. Consider
  # removing them or using vlog.
  colocation_summary = _compute_colocation_summary_from_op(op)
  device_summary = _compute_device_assignment_summary_from_op(op)
  combined_summary = "\n".join([colocation_summary, device_summary])

  if op.traceback is None:
    # Some ops synthesized as part of function or control flow definitions
    # do not have tracebacks.
    filename = "<unknown>"
    definition_traceback = ""
    lineno = 0
    line = ""
    defined_at = "<unknown>"
  else:
    frame = op.traceback.last_user_frame()
    filename = frame.filename
    definition_traceback = traceback.format_list(op.traceback.get_user_frames())
    lineno = frame.lineno
    line = frame.line
    defined_at = f"{filename}:{lineno:d}"

  field_dict = {
      "colocations": colocation_summary,
      "devices": device_summary,
      "devs_and_colocs": combined_summary,
      "defined_at": defined_at,
      "file": filename,
      "lineno": lineno,
      "line": line,
      "definition_traceback": definition_traceback,
  }
  return field_dict


def _build_node_error_message(op):
  """Returns the formatted error message for the given op.

  Args:
    op: The node.

  Returns:
    The formatted error message for the given op with traceback.
  """
  node_error_message = [
      f"Detected at node {op.name!r} defined at (most recent call last):"
  ]
  field_dict = _compute_field_dict(op)

  # Add node traceback.
  for frame in field_dict["definition_traceback"]:
    if "<embedded" not in frame:
      node_error_message.extend(
          [f"  {line}" for line in frame.split("\n") if line.strip()])

  # Add node name.
  node_error_message.append(f"Node: {op.name!r}")

  return "\n".join(node_error_message)

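# The message built above has roughly this shape (illustrative values):
#
#   Detected at node 'foo' defined at (most recent call last):
#       File "my_model.py", line 12, in build
#         x = tf.constant(1.0, name="foo")
#   Node: 'foo'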

def interpolate(message, graph):
  """Interpolates an error message.

  The error message can contain tags of the form `{{node_type node_name}}`,
  which will be parsed to identify ops in `graph`. If an op has a traceback,
  the traceback will be attached to the error message.

  Args:
    message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
        message.

  Returns:
    The error message string with node definition traceback.
  """
  parsed_message, _, node_tags = parse_message(message)
  error_message = ["Graph execution error:", ""]
  for tag in node_tags:
    try:
      op = graph.get_operation_by_name(tag.name)
    except KeyError:
      continue
    else:
      error_message.append(_build_node_error_message(op))

  error_message.append(parsed_message.strip())
  return "\n".join(error_message)
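# A usage sketch (illustrative; assumes a graph containing an op named "foo"):
#
#   import tensorflow as tf
#
#   with tf.Graph().as_default() as g:
#     tf.constant(1.0, name="foo")
#   print(interpolate("Failure! {{node foo}}", g))
#
# The result begins with "Graph execution error:", then the definition
# traceback for node 'foo', then the original message with the node tag kept.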
465