# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Formats and displays profiling information."""

import argparse
import os
import re

import numpy as np

from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import profiling
from tensorflow.python.debug.lib import source_utils

RL = debugger_cli_common.RichLine

# Identifiers accepted by the `--sort_by` flag of the list_profile command.
SORT_OPS_BY_OP_NAME = "node"
SORT_OPS_BY_OP_TYPE = "op_type"
SORT_OPS_BY_OP_TIME = "op_time"
SORT_OPS_BY_EXEC_TIME = "exec_time"
SORT_OPS_BY_START_TIME = "start_time"
SORT_OPS_BY_LINE = "line"

# Names of the filter flags shared by the list_profile and print_source
# commands. They double as the argparse `dest` names.
_DEVICE_NAME_FILTER_FLAG = "device_name_filter"
_NODE_NAME_FILTER_FLAG = "node_name_filter"
_OP_TYPE_FILTER_FLAG = "op_type_filter"


class ProfileDataTableView(object):
  """Table View of profiling data."""

  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Constructor.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
      time_unit: must be in cli_shared.TIME_UNITS.
    """
    self._profile_datum_list = profile_datum_list
    # Start times are shown as-is; the column header labels them "(us)".
    self.formatted_start_time = [
        datum.start_time for datum in profile_datum_list]
    self.formatted_op_time = [
        cli_shared.time_to_readable_str(datum.op_time,
                                        force_time_unit=time_unit)
        for datum in profile_datum_list]
    self.formatted_exec_time = [
        cli_shared.time_to_readable_str(
            datum.node_exec_stats.all_end_rel_micros,
            force_time_unit=time_unit)
        for datum in profile_datum_list]

    # `_column_names` and `_column_sort_ids` are parallel lists: entry i of
    # the latter is the sort identifier for the column named by entry i of the
    # former. `value()` relies on the same column order.
    self._column_names = ["Node",
                          "Op Type",
                          "Start Time (us)",
                          "Op Time (%s)" % time_unit,
                          "Exec Time (%s)" % time_unit,
                          "Filename:Lineno(function)"]
    self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                             SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                             SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]

  def value(self,
            row,
            col,
            device_name_filter=None,
            node_name_filter=None,
            op_type_filter=None):
    """Get the content of a cell of the table.

    Args:
      row: (int) row index.
      col: (int) column index.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.

    Returns:
      A debugger_cli_common.RichLine object representing the content of the
      cell, potentially with a clickable MenuItem.

    Raises:
      IndexError: if the row or column index is out of range.
    """
    menu_item = None
    if col == 0:
      text = self._profile_datum_list[row].node_exec_stats.node_name
    elif col == 1:
      text = self._profile_datum_list[row].op_type
    elif col == 2:
      text = str(self.formatted_start_time[row])
    elif col == 3:
      text = str(self.formatted_op_time[row])
    elif col == 4:
      text = str(self.formatted_exec_time[row])
    elif col == 5:
      # The source-location cell is clickable: it issues a print_source ("ps")
      # command that carries over the active filters and scrolls to the line.
      command = "ps"
      if device_name_filter:
        command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                 device_name_filter)
      if node_name_filter:
        command += " --%s %s" % (_NODE_NAME_FILTER_FLAG, node_name_filter)
      if op_type_filter:
        command += " --%s %s" % (_OP_TYPE_FILTER_FLAG, op_type_filter)
      command += " %s --init_line %d" % (
          self._profile_datum_list[row].file_path,
          self._profile_datum_list[row].line_number)
      menu_item = debugger_cli_common.MenuItem(None, command)
      text = self._profile_datum_list[row].file_line_func
    else:
      raise IndexError("Invalid column index %d." % col)

    return RL(text, font_attr=menu_item)

  def row_count(self):
    """Returns the number of data rows (one per `ProfileDatum`)."""
    return len(self._profile_datum_list)

  def column_count(self):
    """Returns the number of columns in the table."""
    return len(self._column_names)

  def column_names(self):
    """Returns the list of column header strings."""
    return self._column_names

  def column_sort_id(self, col):
    """Returns the sort identifier (a SORT_OPS_BY_* value) of column `col`."""
    return self._column_sort_ids[col]


def _list_profile_filter(
    profile_datum,
    node_name_regex,
    file_path_regex,
    op_type_regex,
    op_time_interval,
    exec_time_interval,
    min_lineno=-1,
    max_lineno=-1):
  """Filter function for list_profile command.

  Args:
    profile_datum: A `ProfileDatum` object.
    node_name_regex: Regular expression pattern object to filter by name.
    file_path_regex: Regular expression pattern object to filter by file path.
    op_type_regex: Regular expression pattern object to filter by op type.
    op_time_interval: `Interval` for filtering op time.
    exec_time_interval: `Interval` for filtering exec time.
    min_lineno: Lower bound for 1-based line number, inclusive.
      If <= 0, has no effect.
    max_lineno: Upper bound for 1-based line number, exclusive.
      If <= 0, has no effect.
    # TODO(cais): Maybe filter by function name.

  Returns:
    True iff profile_datum should be included.
  """
  if node_name_regex and not node_name_regex.match(
      profile_datum.node_exec_stats.node_name):
    return False
  if file_path_regex:
    # A datum without a known file path can never satisfy a file path filter.
    if (not profile_datum.file_path or
        not file_path_regex.match(profile_datum.file_path)):
      return False
  # The line-number bounds only reject data with a known (truthy) line number.
  if (min_lineno > 0 and profile_datum.line_number and
      profile_datum.line_number < min_lineno):
    return False
  if (max_lineno > 0 and profile_datum.line_number and
      profile_datum.line_number >= max_lineno):
    return False
  if (profile_datum.op_type is not None and op_type_regex and
      not op_type_regex.match(profile_datum.op_type)):
    return False
  if op_time_interval is not None and not op_time_interval.contains(
      profile_datum.op_time):
    return False
  if exec_time_interval and not exec_time_interval.contains(
      profile_datum.node_exec_stats.all_end_rel_micros):
    return False
  return True


def _list_profile_sort_key(profile_datum, sort_by):
  """Get a profile_datum property to sort by in list_profile command.

  Args:
    profile_datum: A `ProfileDatum` object.
    sort_by: (string) indicates a value to sort by.
      Must be one of SORT_BY* constants.

  Returns:
    profile_datum property to sort by. Any `sort_by` value that is not one of
    the other SORT_OPS_BY_* constants falls back to the start time.
  """
  if sort_by == SORT_OPS_BY_OP_NAME:
    return profile_datum.node_exec_stats.node_name
  elif sort_by == SORT_OPS_BY_OP_TYPE:
    return profile_datum.op_type
  elif sort_by == SORT_OPS_BY_LINE:
    return profile_datum.file_line_func
  elif sort_by == SORT_OPS_BY_OP_TIME:
    return profile_datum.op_time
  elif sort_by == SORT_OPS_BY_EXEC_TIME:
    return profile_datum.node_exec_stats.all_end_rel_micros
  else:  # sort by start time
    return profile_datum.node_exec_stats.all_start_micros


class ProfileAnalyzer(object):
  """Analyzer for profiling data."""

  def __init__(self, graph, run_metadata):
    """ProfileAnalyzer constructor.

    Builds the argument parsers for the "list_profile" and "print_source"
    commands.

    Args:
      graph: (tf.Graph) Python graph object.
      run_metadata: A `RunMetadata` protobuf object.

    Raises:
      ValueError: If run_metadata is None.
    """
    self._graph = graph
    if not run_metadata:
      raise ValueError("No RunMetadata passed for profile analysis.")
    self._run_metadata = run_metadata
    self._arg_parsers = {}
    ap = argparse.ArgumentParser(
        description="List nodes profile information.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="filter op type by regex.")
    # TODO(annarev): allow file filtering at non-stack top position.
    ap.add_argument(
        "-f",
        "--file_path_filter",
        dest="file_path_filter",
        type=str,
        default="",
        help="filter by file name at the top position of node's creation "
        "stack that does not belong to TensorFlow library.")
    ap.add_argument(
        "--min_lineno",
        dest="min_lineno",
        type=int,
        default=-1,
        help="(Inclusive) lower bound for 1-based line number in source file. "
        "If <= 0, has no effect.")
    ap.add_argument(
        "--max_lineno",
        dest="max_lineno",
        type=int,
        default=-1,
        help="(Exclusive) upper bound for 1-based line number in source file. "
        "If <= 0, has no effect.")
    ap.add_argument(
        "-e",
        "--execution_time",
        dest="execution_time",
        type=str,
        default="",
        help="Filter by execution time interval "
        "(includes compute plus pre- and post -processing time). "
        "Supported units are s, ms and us (default). "
        "E.g. -e >100s, -e <100, -e [100us,1000ms]")
    ap.add_argument(
        "-o",
        "--op_time",
        dest="op_time",
        type=str,
        default="",
        # The examples use this option's own short flag (-o), not -e.
        help="Filter by op time interval (only includes compute time). "
        "Supported units are s, ms and us (default). "
        "E.g. -o >100s, -o <100, -o [100us,1000ms]")
    ap.add_argument(
        "-s",
        "--sort_by",
        dest="sort_by",
        type=str,
        default=SORT_OPS_BY_START_TIME,
        help=("the field to sort the data by: (%s)" %
              " | ".join([SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                          SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE])))
    ap.add_argument(
        "-r",
        "--reverse",
        dest="reverse",
        action="store_true",
        help="sort the data in reverse (descending) order")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")

    self._arg_parsers["list_profile"] = ap

    ap = argparse.ArgumentParser(
        description="Print a Python source file with line-level profile "
        "information",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "source_file_path",
        type=str,
        help="Path to the source_file_path")
    ap.add_argument(
        "--cost_type",
        type=str,
        choices=["exec_time", "op_time"],
        default="exec_time",
        help="Type of cost to display")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="Filter op type by regex.")
    ap.add_argument(
        "--init_line",
        dest="init_line",
        type=int,
        default=0,
        help="The 1-based line number to scroll to initially.")

    self._arg_parsers["print_source"] = ap

  def list_profile(self, args, screen_info=None):
    """Command handler for list_profile.

    List per-operation profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    screen_cols = 80
    if screen_info and "cols" in screen_info:
      screen_cols = screen_info["cols"]

    parsed = self._arg_parsers["list_profile"].parse_args(args)
    # An empty flag value (the default) disables the corresponding filter.
    op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
                        if parsed.op_time else None)
    exec_time_interval = (
        command_parser.parse_time_interval(parsed.execution_time)
        if parsed.execution_time else None)
    node_name_regex = (re.compile(parsed.node_name_filter)
                       if parsed.node_name_filter else None)
    file_path_regex = (re.compile(parsed.file_path_filter)
                       if parsed.file_path_filter else None)
    op_type_regex = (re.compile(parsed.op_type_filter)
                     if parsed.op_type_filter else None)
    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    output = debugger_cli_common.RichTextLines([""])
    data_generator = self._get_profile_data_generator()
    dev_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(dev_stats)
    for index, device_stats in enumerate(dev_stats):
      if device_name_regex and not device_name_regex.match(
          device_stats.device):
        continue
      profile_data = [
          datum for datum in data_generator(device_stats)
          if _list_profile_filter(
              datum, node_name_regex, file_path_regex, op_type_regex,
              op_time_interval, exec_time_interval,
              min_lineno=parsed.min_lineno, max_lineno=parsed.max_lineno)]
      profile_data = sorted(
          profile_data,
          key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
          reverse=parsed.reverse)
      # One table per matching device; `index` is the position of the device
      # among all devices, not only among the matching ones.
      output.extend(
          self._get_list_profile_lines(
              device_stats.device, index, device_count,
              profile_data, parsed.sort_by, parsed.reverse, parsed.time_unit,
              device_name_filter=parsed.device_name_filter,
              node_name_filter=parsed.node_name_filter,
              op_type_filter=parsed.op_type_filter,
              screen_cols=screen_cols))
    return output

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that generates `ProfileDatum` objects.
    """
    node_to_file_path = {}
    node_to_line_number = {}
    node_to_func_name = {}
    node_to_op_type = {}
    for op in self._graph.get_operations():
      # Reset per op so that an op with an empty traceback gets the same
      # placeholder values the generator below uses for unknown nodes, rather
      # than values left over from the previous op.
      file_path = ""
      line_num = 0
      func_name = ""
      # Walk the traceback in reverse and stop at the first frame that is not
      # guessed to be TensorFlow library code. If every frame is library code,
      # the last frame visited is used.
      for trace_entry in reversed(op.traceback):
        file_path = trace_entry[0]
        line_num = trace_entry[1]
        func_name = trace_entry[2]
        if not source_utils.guess_is_tensorflow_py_library(file_path):
          break
      node_to_file_path[op.name] = file_path
      node_to_line_number[op.name] = line_num
      node_to_func_name[op.name] = func_name
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        # The _SOURCE and _SINK nodes are excluded from the profile.
        if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
          continue
        yield profiling.ProfileDatum(
            device_step_stats.device,
            node_stats,
            node_to_file_path.get(node_stats.node_name, ""),
            node_to_line_number.get(node_stats.node_name, 0),
            node_to_func_name.get(node_stats.node_name, ""),
            node_to_op_type.get(node_stats.node_name, ""))
    return profile_data_generator

  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Identifier of column to sort. Sort identifier
          must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
          SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME
          or SORT_OPS_BY_LINE.
      sort_reverse: (bool) Whether to sort in descending instead of default
          (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    table = ProfileDataTableView(profile_datum_list, time_unit=time_unit)
    num_cols = table.column_count()
    column_names = table.column_names()

    # Calculate total time early to calculate column widths.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    # One entry per leading table column: Node, Op Type, Start Time, Op Time,
    # Exec Time. The op type and start time cells are left empty so that each
    # total is printed underneath the column it sums.
    device_total_row = [
        "Device Total", "", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)]

    # Render every cell once; the cells are used both to measure the column
    # widths and to emit the data rows.
    cells = [
        [table.value(row_index,
                     col,
                     device_name_filter=device_name_filter,
                     node_name_filter=node_name_filter,
                     op_type_filter=op_type_filter)
         for col in range(num_cols)]
        for row_index in range(table.row_count())]

    # Calculate column widths.
    column_widths = [len(column_name) for column_name in column_names]
    for col, total_cell in enumerate(device_total_row):
      column_widths[col] = max(column_widths[col], len(total_cell))
    for col in range(num_cols):
      for row_cells in cells:
        column_widths[col] = max(column_widths[col], len(row_cells[col]))
      column_widths[col] += 2  # add margin between columns

    # Add device name.
    output = [RL("-" * screen_cols)]
    device_row = "Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)
    output.append(RL(device_row))
    output.append(RL())

    # Add headers.
    base_command = "list_profile"
    header = RL()
    for col in range(num_cols):
      column_name = column_names[col]
      sort_id = table.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      # Clicking the header of the column that is currently sorted in
      # ascending order toggles the sort to descending.
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      header += RL(column_name, font_attr=[head_menu_item, "bold"])
      header += RL(" " * (column_widths[col] - len(column_name)))

    output.append(header)

    # Add data rows.
    for row_cells in cells:
      new_row = RL()
      for col, new_cell in enumerate(row_cells):
        new_row += new_cell
        new_row += RL(" " * (column_widths[col] - len(new_cell)))
      output.append(new_row)

    # Add stat totals.
    totals_str = "".join(
        ("{:<%d}" % width).format(total_cell)
        for width, total_cell in zip(column_widths, device_total_row))
    output.append(RL())
    output.append(RL(totals_str))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)

  def _measure_list_profile_column_widths(self, profile_data):
    """Determine the maximum column widths for a profile data table.

    Args:
      profile_data: A `ProfileDataTableView` object.

    Returns:
      List of column widths in the same order as the columns of the table.
      Each width is at least the length of the column name and at least the
      length of the widest cell in the column plus a margin of 2.
    """
    num_columns = len(profile_data.column_names())
    widths = [len(column_name) for column_name in profile_data.column_names()]
    for row in range(profile_data.row_count()):
      for col in range(num_columns):
        # Cells are read through value(), the cell accessor that
        # `ProfileDataTableView` provides.
        widths[col] = max(
            widths[col], len(profile_data.value(row, col)) + 2)
    return widths

  _LINE_COST_ATTR = cli_shared.COLOR_CYAN
  _LINE_NUM_ATTR = cli_shared.COLOR_YELLOW
  _NUM_NODES_HEAD = "#nodes"
  _NUM_EXECS_SUB_HEAD = "(#execs)"
  _LINENO_HEAD = "lineno"
  _SOURCE_HEAD = "source"

  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)
    # Expand "~" once and use the same path for annotating, for loading the
    # source and for the generated list_profile commands, so that they all
    # refer to the same file.
    source_file_path = os.path.expanduser(parsed.source_file_path)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    profile_data = []
    data_generator = self._get_profile_data_generator()
    for device_stats in self._run_metadata.step_stats.dev_stats:
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        source_file_path,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           "  --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           "  --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           "  --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    # The largest per-line cost is what the cost bars are normalized against.
    max_total_cost = 0
    for line_index in source_annotation:
      total_cost = self._get_total_cost(source_annotation[line_index],
                                        parsed.cost_type)
      max_total_cost = max(max_total_cost, total_cost)

    source_lines, line_num_width = source_utils.load_source(source_file_path)

    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }
    # Width of the blank prefix of source lines without profile information.
    blank_prefix_width = sum(
        column_widths[col_name] for col_name in column_widths
        if col_name != "line_number")

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    # Parts of the per-line list_profile ("lp") command that are the same for
    # every line of the file.
    file_path_filter = re.escape(source_file_path) + "$"
    filter_suffix = ""
    if parsed.device_name_filter:
      filter_suffix += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                     parsed.device_name_filter)
    if parsed.node_name_filter:
      filter_suffix += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                     parsed.node_name_filter)
    if parsed.op_type_filter:
      filter_suffix += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                     parsed.op_type_filter)

    output_annotations = {}
    for i, line in enumerate(source_lines):
      lineno = i + 1
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        line_cost = self._get_total_cost(annotation, parsed.cost_type)

        cost_bar = self._render_normalized_cost_bar(
            line_cost, max_total_cost, cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            line_cost, force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # Clicking the node count lists the nodes created at exactly this
        # line: the half-open range [lineno, lineno + 1).
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        command += filter_suffix
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        annotated_line = RL(" " * blank_prefix_width)

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      if parsed.init_line == lineno:
        # Record the index of the output line to scroll to initially.
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)

  def _get_total_cost(self, aggregated_profile, cost_type):
    """Get the total cost of the requested type from an aggregated profile.

    Args:
      aggregated_profile: Object with `total_exec_time` and `total_op_time`
        attributes.
      cost_type: (str) "exec_time" or "op_time".

    Returns:
      The `total_exec_time` or `total_op_time` of `aggregated_profile`.

    Raises:
      ValueError: If `cost_type` is neither "exec_time" nor "op_time".
    """
    if cost_type == "exec_time":
      return aggregated_profile.total_exec_time
    elif cost_type == "op_time":
      return aggregated_profile.total_op_time
    else:
      raise ValueError("Unsupported cost type: %s" % cost_type)

  def _render_normalized_cost_bar(self, cost, max_cost, length):
    """Render a text bar representing a normalized cost.

    Args:
      cost: the absolute value of the cost.
      max_cost: the maximum cost value to normalize the absolute cost with.
      length: (int) length of the cost bar, in number of characters, excluding
        the brackets on the two ends.

    Returns:
      An instance of debugger_cli_common.RichLine.
    """
    if max_cost > 0:
      num_ticks = int(np.ceil(float(cost) / max_cost * length))
    else:
      # Nothing to normalize against (e.g., every cost is zero); fall through
      # to the one-tick minimum instead of dividing by zero.
      num_ticks = 0
    # Never overflow the bar, and always show at least 1 tick.
    num_ticks = max(min(num_ticks, length), 1)
    output = RL("[", font_attr=self._LINE_COST_ATTR)
    output += RL("|" * num_ticks + " " * (length - num_ticks),
                 font_attr=["bold", self._LINE_COST_ATTR])
    output += RL("]", font_attr=self._LINE_COST_ATTR)
    return output

  def get_help(self, handler_name):
    """Returns the argparse-formatted help text of the named command handler.

    Args:
      handler_name: (str) "list_profile" or "print_source".

    Raises:
      KeyError: If no parser is registered under `handler_name`.
    """
    return self._arg_parsers[handler_name].format_help()


def create_profiler_ui(graph,
                       run_metadata,
                       ui_type="curses",
                       on_ui_exit=None,
                       config=None):
  """Create a profiler UI based on a `tf.Graph` and `RunMetadata`.

  Args:
    graph: Python `Graph` object.
    run_metadata: A `RunMetadata` protobuf object.
    ui_type: (str) requested UI type, e.g., "curses", "readline".
    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
    config: An instance of `cli_config.CLIConfig`.

  Returns:
    (base_ui.BaseUI) A BaseUI subtype object with the list_profile ("lp") and
    print_source ("ps") command handlers registered.
  """
  del config  # Currently unused.

  analyzer = ProfileAnalyzer(graph, run_metadata)

  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
  cli.register_command_handler(
      "list_profile",
      analyzer.list_profile,
      analyzer.get_help("list_profile"),
      prefix_aliases=["lp"])
  cli.register_command_handler(
      "print_source",
      analyzer.print_source,
      analyzer.get_help("print_source"),
      prefix_aliases=["ps"])

  return cli