xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/profile_analyzer_cli.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Formats and displays profiling information."""
16
17import argparse
18import os
19import re
20
21import numpy as np
22
23from tensorflow.python.debug.cli import cli_shared
24from tensorflow.python.debug.cli import command_parser
25from tensorflow.python.debug.cli import debugger_cli_common
26from tensorflow.python.debug.cli import ui_factory
27from tensorflow.python.debug.lib import profiling
28from tensorflow.python.debug.lib import source_utils
29
# Short alias used throughout this module to build rich (attributed) lines.
RL = debugger_cli_common.RichLine

# Values accepted by `list_profile -s/--sort_by`, listed in the same order as
# the columns of the profile table.
SORT_OPS_BY_OP_NAME = "node"
SORT_OPS_BY_OP_TYPE = "op_type"
SORT_OPS_BY_START_TIME = "start_time"
SORT_OPS_BY_OP_TIME = "op_time"
SORT_OPS_BY_EXEC_TIME = "exec_time"
SORT_OPS_BY_LINE = "line"

# Long flag names (also used as argparse `dest`s) of the regex filters that
# both `list_profile` and `print_source` accept.
_DEVICE_NAME_FILTER_FLAG = "device_name_filter"
_NODE_NAME_FILTER_FLAG = "node_name_filter"
_OP_TYPE_FILTER_FLAG = "op_type_filter"
42
43
class ProfileDataTableView(object):
  """Table view of profiling data, one row per `ProfileDatum`.

  Columns, in order: node name, op type, start time, op time, exec time and
  the node's source location ("Filename:Lineno(function)").
  """

  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Constructor.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
      time_unit: must be in cli_shared.TIME_UNITS. Applies to the op-time and
        exec-time columns.
    """
    self._profile_datum_list = profile_datum_list
    # Start times are shown unconverted (the column header labels them "us");
    # only the op and exec times are rendered in `time_unit`.
    self.formatted_start_time = [
        datum.start_time for datum in profile_datum_list]
    self.formatted_op_time = [
        cli_shared.time_to_readable_str(datum.op_time,
                                        force_time_unit=time_unit)
        for datum in profile_datum_list]
    self.formatted_exec_time = [
        cli_shared.time_to_readable_str(
            datum.node_exec_stats.all_end_rel_micros,
            force_time_unit=time_unit)
        for datum in profile_datum_list]

    self._column_names = ["Node",
                          "Op Type",
                          "Start Time (us)",
                          "Op Time (%s)" % time_unit,
                          "Exec Time (%s)" % time_unit,
                          "Filename:Lineno(function)"]
    # Sort identifier for each column, index-aligned with `_column_names`.
    self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                             SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                             SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]

  def value(self,
            row,
            col,
            device_name_filter=None,
            node_name_filter=None,
            op_type_filter=None):
    """Get the content of a cell of the table.

    Args:
      row: (int) row index.
      col: (int) column index.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      The three filters are only forwarded into the `ps` command attached to
      the source-location column.

    Returns:
      A debugger_cli_common.RichLine object representing the content of the
      cell. Cells of the source-location column carry a clickable MenuItem
      (a `ps` command) when the datum has a known source file path.

    Raises:
      IndexError: if row or column index is out of range.
    """
    menu_item = None
    if col == 0:
      text = self._profile_datum_list[row].node_exec_stats.node_name
    elif col == 1:
      text = self._profile_datum_list[row].op_type
    elif col == 2:
      text = str(self.formatted_start_time[row])
    elif col == 3:
      text = str(self.formatted_op_time[row])
    elif col == 4:
      text = str(self.formatted_exec_time[row])
    elif col == 5:
      datum = self._profile_datum_list[row]
      text = datum.file_line_func
      # Without a file path the `ps` command would be missing its required
      # source-file argument, so such cells are left non-clickable.
      if datum.file_path:
        command = "ps"
        if device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   device_name_filter)
        if node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG, node_name_filter)
        if op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG, op_type_filter)
        command += " %s --init_line %d" % (datum.file_path, datum.line_number)
        menu_item = debugger_cli_common.MenuItem(None, command)
    else:
      raise IndexError("Invalid column index %d." % col)

    return RL(text, font_attr=menu_item)

  def row_count(self):
    """Number of data rows (one per `ProfileDatum`)."""
    return len(self._profile_datum_list)

  def column_count(self):
    """Number of columns in the table."""
    return len(self._column_names)

  def column_names(self):
    """List of column header strings."""
    return self._column_names

  def column_sort_id(self, col):
    """Sort identifier (a SORT_OPS_BY_* value) for column `col`."""
    return self._column_sort_ids[col]
140
141
142def _list_profile_filter(
143    profile_datum,
144    node_name_regex,
145    file_path_regex,
146    op_type_regex,
147    op_time_interval,
148    exec_time_interval,
149    min_lineno=-1,
150    max_lineno=-1):
151  """Filter function for list_profile command.
152
153  Args:
154    profile_datum: A `ProfileDatum` object.
155    node_name_regex: Regular expression pattern object to filter by name.
156    file_path_regex: Regular expression pattern object to filter by file path.
157    op_type_regex: Regular expression pattern object to filter by op type.
158    op_time_interval: `Interval` for filtering op time.
159    exec_time_interval: `Interval` for filtering exec time.
160    min_lineno: Lower bound for 1-based line number, inclusive.
161      If <= 0, has no effect.
162    max_lineno: Upper bound for 1-based line number, exclusive.
163      If <= 0, has no effect.
164    # TODO(cais): Maybe filter by function name.
165
166  Returns:
167    True iff profile_datum should be included.
168  """
169  if node_name_regex and not node_name_regex.match(
170      profile_datum.node_exec_stats.node_name):
171    return False
172  if file_path_regex:
173    if (not profile_datum.file_path or
174        not file_path_regex.match(profile_datum.file_path)):
175      return False
176  if (min_lineno > 0 and profile_datum.line_number and
177      profile_datum.line_number < min_lineno):
178    return False
179  if (max_lineno > 0 and profile_datum.line_number and
180      profile_datum.line_number >= max_lineno):
181    return False
182  if (profile_datum.op_type is not None and op_type_regex and
183      not op_type_regex.match(profile_datum.op_type)):
184    return False
185  if op_time_interval is not None and not op_time_interval.contains(
186      profile_datum.op_time):
187    return False
188  if exec_time_interval and not exec_time_interval.contains(
189      profile_datum.node_exec_stats.all_end_rel_micros):
190    return False
191  return True
192
193
def _list_profile_sort_key(profile_datum, sort_by):
  """Return the value of `profile_datum` that `list_profile` sorts on.

  Args:
    profile_datum: A `ProfileDatum` object.
    sort_by: (string) Sort identifier, expected to be one of the SORT_OPS_BY_*
      constants. Unrecognized identifiers sort by start time.

  Returns:
    The attribute of `profile_datum` (or of its `node_exec_stats`) selected
    by `sort_by`.
  """
  # Duration-based keys.
  if sort_by == SORT_OPS_BY_OP_TIME:
    return profile_datum.op_time
  if sort_by == SORT_OPS_BY_EXEC_TIME:
    return profile_datum.node_exec_stats.all_end_rel_micros
  # Textual keys.
  if sort_by == SORT_OPS_BY_OP_NAME:
    return profile_datum.node_exec_stats.node_name
  if sort_by == SORT_OPS_BY_OP_TYPE:
    return profile_datum.op_type
  if sort_by == SORT_OPS_BY_LINE:
    return profile_datum.file_line_func
  # SORT_OPS_BY_START_TIME, and the fallback for anything else.
  return profile_datum.node_exec_stats.all_start_micros
216    return profile_datum.node_exec_stats.all_start_micros
217
218
class ProfileAnalyzer(object):
  """Analyzer for profiling data.

  Provides the `list_profile` and `print_source` command handlers, and their
  help text, over a graph and the `RunMetadata` of one run.
  """

  def __init__(self, graph, run_metadata):
    """ProfileAnalyzer constructor.

    Args:
      graph: (tf.Graph) Python graph object.
      run_metadata: A `RunMetadata` protobuf object.

    Raises:
      ValueError: If run_metadata is None (or otherwise falsy).
    """
    self._graph = graph
    if not run_metadata:
      raise ValueError("No RunMetadata passed for profile analysis.")
    self._run_metadata = run_metadata
    self._arg_parsers = {}
    ap = argparse.ArgumentParser(
        description="List nodes profile information.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="filter op type by regex.")
    # TODO(annarev): allow file filtering at non-stack top position.
    ap.add_argument(
        "-f",
        "--file_path_filter",
        dest="file_path_filter",
        type=str,
        default="",
        help="filter by file name at the top position of node's creation "
             "stack that does not belong to TensorFlow library.")
    ap.add_argument(
        "--min_lineno",
        dest="min_lineno",
        type=int,
        default=-1,
        help="(Inclusive) lower bound for 1-based line number in source file. "
             "If <= 0, has no effect.")
    ap.add_argument(
        "--max_lineno",
        dest="max_lineno",
        type=int,
        default=-1,
        help="(Exclusive) upper bound for 1-based line number in source file. "
             "If <= 0, has no effect.")
    ap.add_argument(
        "-e",
        "--execution_time",
        dest="execution_time",
        type=str,
        default="",
        help="Filter by execution time interval "
             "(includes compute plus pre- and post -processing time). "
             "Supported units are s, ms and us (default). "
             "E.g. -e >100s, -e <100, -e [100us,1000ms]")
    ap.add_argument(
        "-o",
        "--op_time",
        dest="op_time",
        type=str,
        default="",
        help="Filter by op time interval (only includes compute time). "
             "Supported units are s, ms and us (default). "
             "E.g. -o >100s, -o <100, -o [100us,1000ms]")
    ap.add_argument(
        "-s",
        "--sort_by",
        dest="sort_by",
        type=str,
        default=SORT_OPS_BY_START_TIME,
        help=("the field to sort the data by: (%s)" %
              " | ".join([SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                          SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE])))
    ap.add_argument(
        "-r",
        "--reverse",
        dest="reverse",
        action="store_true",
        help="sort the data in reverse (descending) order")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")

    self._arg_parsers["list_profile"] = ap

    ap = argparse.ArgumentParser(
        description="Print a Python source file with line-level profile "
                    "information",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "source_file_path",
        type=str,
        help="Path to the source_file_path")
    ap.add_argument(
        "--cost_type",
        type=str,
        choices=["exec_time", "op_time"],
        default="exec_time",
        help="Type of cost to display")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="Filter op type by regex.")
    ap.add_argument(
        "--init_line",
        dest="init_line",
        type=int,
        default=0,
        help="The 1-based line number to scroll to initially.")

    self._arg_parsers["print_source"] = ap

  def list_profile(self, args, screen_info=None):
    """Command handler for list_profile.

    List per-operation profile information, as one table per device whose
    name matches the device-name filter.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    screen_cols = 80
    if screen_info and "cols" in screen_info:
      screen_cols = screen_info["cols"]

    parsed = self._arg_parsers["list_profile"].parse_args(args)
    op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
                        if parsed.op_time else None)
    exec_time_interval = (
        command_parser.parse_time_interval(parsed.execution_time)
        if parsed.execution_time else None)
    node_name_regex = (re.compile(parsed.node_name_filter)
                       if parsed.node_name_filter else None)
    file_path_regex = (re.compile(parsed.file_path_filter)
                       if parsed.file_path_filter else None)
    op_type_regex = (re.compile(parsed.op_type_filter)
                     if parsed.op_type_filter else None)

    output = debugger_cli_common.RichTextLines([""])
    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)
    data_generator = self._get_profile_data_generator()
    device_count = len(self._run_metadata.step_stats.dev_stats)
    for index in range(device_count):
      device_stats = self._run_metadata.step_stats.dev_stats[index]
      if not device_name_regex or device_name_regex.match(device_stats.device):
        profile_data = [
            datum for datum in data_generator(device_stats)
            if _list_profile_filter(
                datum, node_name_regex, file_path_regex, op_type_regex,
                op_time_interval, exec_time_interval,
                min_lineno=parsed.min_lineno, max_lineno=parsed.max_lineno)]
        profile_data = sorted(
            profile_data,
            key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
            reverse=parsed.reverse)
        output.extend(
            self._get_list_profile_lines(
                device_stats.device, index, device_count,
                profile_data, parsed.sort_by, parsed.reverse, parsed.time_unit,
                device_name_filter=parsed.device_name_filter,
                node_name_filter=parsed.node_name_filter,
                op_type_filter=parsed.op_type_filter,
                screen_cols=screen_cols))
    return output

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that takes the step stats of one device and yields a
      `ProfileDatum` for each of its node stats, skipping the `_SOURCE` and
      `_SINK` nodes.
    """
    node_to_file_path = {}
    node_to_line_number = {}
    node_to_func_name = {}
    node_to_op_type = {}
    for op in self._graph.get_operations():
      # Reset per op: an op with an empty traceback must not inherit the
      # location of the previously processed op (or hit an unbound name if it
      # is the first op). These match the generator's lookup defaults below.
      file_path, line_num, func_name = "", 0, ""
      # Scan the creation traceback in reverse order and stop at the first
      # frame not guessed to be TensorFlow library code. If every frame looks
      # like library code, the last frame scanned is kept.
      for trace_entry in reversed(op.traceback):
        file_path = trace_entry[0]
        line_num = trace_entry[1]
        func_name = trace_entry[2]
        if not source_utils.guess_is_tensorflow_py_library(file_path):
          break
      node_to_file_path[op.name] = file_path
      node_to_line_number[op.name] = line_num
      node_to_func_name[op.name] = func_name
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
          continue
        yield profiling.ProfileDatum(
            device_step_stats.device,
            node_stats,
            node_to_file_path.get(node_stats.node_name, ""),
            node_to_line_number.get(node_stats.node_name, 0),
            node_to_func_name.get(node_stats.node_name, ""),
            node_to_op_type.get(node_stats.node_name, ""))
    return profile_data_generator

  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Identifier of column to sort. Sort identifier
          must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
          SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME
          or SORT_OPS_BY_LINE.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # One entry per leading table column (Node, Op Type, Start Time, Op Time,
    # Exec Time) so each total lines up under its own header; start times
    # have no meaningful total.
    device_total_row = [
        "Device Total", "", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)]

    # Render every cell once; the same RichLines are used both to measure the
    # column widths and to build the output rows.
    cells = [
        [profile_data.value(
            row_index,
            col,
            device_name_filter=device_name_filter,
            node_name_filter=node_name_filter,
            op_type_filter=op_type_filter)
         for col in range(profile_data.column_count())]
        for row_index in range(profile_data.row_count())]

    # Calculate column widths.
    column_widths = [
        len(column_name) for column_name in profile_data.column_names()]
    for col, total in enumerate(device_total_row):
      column_widths[col] = max(column_widths[col], len(total))
    for row_cells in cells:
      for col, cell in enumerate(row_cells):
        column_widths[col] = max(column_widths[col], len(cell))
    column_widths = [width + 2 for width in column_widths]  # column margin

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers. Clicking a header sorts by that column; clicking the
    # column that is currently sorted ascending reverses the order.
    base_command = "list_profile"
    header_row = RL()
    for col in range(profile_data.column_count()):
      column_name = profile_data.column_names()[col]
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      header_row += RL(column_name, font_attr=[head_menu_item, "bold"])
      header_row += RL(" " * (column_widths[col] - len(column_name)))

    output.append(header_row)

    # Add data rows.
    for row_cells in cells:
      new_row = RL()
      for col, new_cell in enumerate(row_cells):
        new_row += new_cell
        new_row += RL(" " * (column_widths[col] - len(new_cell)))
      output.append(new_row)

    # Add stat totals.
    total_str = ""
    for width, total in zip(column_widths, device_total_row):
      total_str += ("{:<%d}" % width).format(total)
    output.append(RL())
    output.append(RL(total_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)

  def _measure_list_profile_column_widths(self, profile_data):
    """Determine the maximum column widths for each data list.

    Args:
      profile_data: A `ProfileDataTableView` object.

    Returns:
      List of column widths in the same order as columns in data: for each
      column, the larger of the header length and the longest cell + 2.
    """
    widths = [len(column_name) for column_name in profile_data.column_names()]
    for row in range(profile_data.row_count()):
      for col in range(len(widths)):
        widths[col] = max(widths[col], len(profile_data.value(row, col)) + 2)
    return widths

  _LINE_COST_ATTR = cli_shared.COLOR_CYAN
  _LINE_NUM_ATTR = cli_shared.COLOR_YELLOW
  _NUM_NODES_HEAD = "#nodes"
  _NUM_EXECS_SUB_HEAD = "(#execs)"
  _LINENO_HEAD = "lineno"
  _SOURCE_HEAD = "source"

  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)
    # Expand "~" once so that annotation, source loading and the generated
    # `lp` drill-down commands all refer to the same file.
    source_file_path = os.path.expanduser(parsed.source_file_path)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    profile_data = []
    data_generator = self._get_profile_data_generator()
    device_count = len(self._run_metadata.step_stats.dev_stats)
    for index in range(device_count):
      device_stats = self._run_metadata.step_stats.dev_stats[index]
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        source_file_path,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # The largest per-line cost is used to normalize the cost bars.
    max_total_cost = 0
    for line_index in source_annotation:
      total_cost = self._get_total_cost(source_annotation[line_index],
                                        parsed.cost_type)
      max_total_cost = max(max_total_cost, total_cost)

    source_lines, line_num_width = source_utils.load_source(source_file_path)

    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    output_annotations = {}
    for i, line in enumerate(source_lines):
      lineno = i + 1
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        line_cost = self._get_total_cost(annotation, parsed.cost_type)
        cost_bar = self._render_normalized_cost_bar(
            line_cost, max_total_cost, cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            line_cost, force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Clicking the node count drills down into the ops on this line.
        file_path_filter = re.escape(source_file_path) + "$"
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        annotated_line = RL(
            " " * sum(column_widths[col_name] for col_name in column_widths
                      if col_name != "line_number"))

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)

  def _get_total_cost(self, aggregated_profile, cost_type):
    """Return the total exec time or total op time of an aggregated profile.

    Raises:
      ValueError: If cost_type is neither "exec_time" nor "op_time".
    """
    if cost_type == "exec_time":
      return aggregated_profile.total_exec_time
    elif cost_type == "op_time":
      return aggregated_profile.total_op_time
    else:
      raise ValueError("Unsupported cost type: %s" % cost_type)

  def _render_normalized_cost_bar(self, cost, max_cost, length):
    """Render a text bar representing a normalized cost.

    Args:
      cost: the absolute value of the cost.
      max_cost: the maximum cost value to normalize the absolute cost with.
        If it is not positive (e.g. every cost is 0), the bar gets the minimum
        single tick instead of dividing by zero.
      length: (int) length of the cost bar, in number of characters, excluding
        the brackets on the two ends.

    Returns:
      An instance of debugger_cli_common.RichLine.
    """
    if max_cost > 0:
      num_ticks = int(np.ceil(float(cost) / max_cost * length))
    else:
      num_ticks = 0
    # Minimum is 1 tick; never overflow the bar if cost exceeds max_cost.
    num_ticks = max(1, min(num_ticks, length))
    output = RL("[", font_attr=self._LINE_COST_ATTR)
    output += RL("|" * num_ticks + " " * (length - num_ticks),
                 font_attr=["bold", self._LINE_COST_ATTR])
    output += RL("]", font_attr=self._LINE_COST_ATTR)
    return output

  def get_help(self, handler_name):
    """Return the argparse help text for `handler_name`."""
    return self._arg_parsers[handler_name].format_help()
762
763
def create_profiler_ui(graph,
                       run_metadata,
                       ui_type="curses",
                       on_ui_exit=None,
                       config=None):
  """Build a profiler UI for a `tf.Graph` and its `RunMetadata`.

  Args:
    graph: Python `Graph` object.
    run_metadata: A `RunMetadata` protobuf object.
    ui_type: (str) requested UI type, e.g., "curses", "readline".
    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
    config: An instance of `cli_config.CLIConfig`. Currently ignored.

  Returns:
    (base_ui.BaseUI) A BaseUI subtype object with the `list_profile` (alias
      `lp`) and `print_source` (alias `ps`) commands registered.
  """
  del config  # Currently unused.

  analyzer = ProfileAnalyzer(graph, run_metadata)
  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)

  # (command name, handler, prefix alias), registered in this order.
  command_table = (
      ("list_profile", analyzer.list_profile, "lp"),
      ("print_source", analyzer.print_source, "ps"),
  )
  for command_name, handler, alias in command_table:
    cli.register_command_handler(
        command_name,
        handler,
        analyzer.get_help(command_name),
        prefix_aliases=[alias])

  return cli
798  return cli
799