1"""Intermediate layer between `Timer` and `valgrind`.""" 2import collections 3import enum 4import dataclasses 5import itertools as it 6import os 7import pickle 8import re 9import shutil 10import subprocess 11import sys 12import textwrap 13from typing import ( 14 cast, Any, Callable, DefaultDict, Dict, Iterator, List, NamedTuple, 15 Optional, Tuple, Union, TYPE_CHECKING) 16 17import torch 18from torch.utils.benchmark.utils import common, cpp_jit 19from torch.utils.benchmark.utils._stubs import CallgrindModuleType 20import operator 21 22 23__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"] 24 25 26if TYPE_CHECKING: 27 CompletedProcessType = subprocess.CompletedProcess[str] 28else: 29 CompletedProcessType = subprocess.CompletedProcess 30 31 32class FunctionCount(NamedTuple): 33 # TODO(#105471): Rename the count field 34 count: int # type: ignore[assignment] 35 function: str 36 37 38@dataclasses.dataclass(repr=False, eq=False, frozen=True) 39class FunctionCounts: 40 """Container for manipulating Callgrind results. 41 42 It supports: 43 1) Addition and subtraction to combine or diff results. 44 2) Tuple-like indexing. 45 3) A `denoise` function which strips CPython calls which are known to 46 be non-deterministic and quite noisy. 47 4) Two higher order methods (`filter` and `transform`) for custom 48 manipulation. 49 """ 50 _data: Tuple[FunctionCount, ...] 51 inclusive: bool 52 truncate_rows: bool = True 53 54 # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines 55 # the print settings. This is simply to allow hermetic unit tests. 56 _linewidth: Optional[int] = None 57 58 def __iter__(self) -> Iterator[FunctionCount]: 59 yield from self._data 60 61 def __len__(self) -> int: 62 return len(self._data) 63 64 def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]: 65 data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item] 66 return ( 67 FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False) 68 if isinstance(data, tuple) else data 69 ) 70 71 def __repr__(self) -> str: 72 count_len = 0 73 for c, _ in self: 74 # Account for sign in string length. 75 count_len = max(count_len, len(str(c)) + int(c < 0)) 76 77 lines = [] 78 linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth 79 fn_str_len = max(linewidth - count_len - 4, 40) 80 for c, fn in self: 81 if len(fn) > fn_str_len: 82 left_len = int((fn_str_len - 5) // 2) 83 fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):] 84 lines.append(f" {c:>{count_len}} {fn}") 85 86 if self.truncate_rows and len(lines) > 18: 87 lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:] 88 89 if not self.inclusive: 90 lines.extend(["", f"Total: {self.sum()}"]) 91 92 return "\n".join([super().__repr__()] + lines) 93 94 def __add__( 95 self, 96 other: "FunctionCounts", 97 ) -> "FunctionCounts": 98 return self._merge(other, lambda c: c) 99 100 def __sub__( 101 self, 102 other: "FunctionCounts", 103 ) -> "FunctionCounts": 104 return self._merge(other, operator.neg) 105 106 def __mul__(self, other: Union[int, float]) -> "FunctionCounts": 107 return self._from_dict({ 108 fn: int(c * other) for c, fn in self._data 109 }, self.inclusive) 110 111 def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts": 112 """Apply `map_fn` to all of the function names. 113 114 This can be used to regularize function names (e.g. stripping irrelevant 115 parts of the file path), coalesce entries by mapping multiple functions 116 to the same name (in which case the counts are added together), etc. 117 """ 118 counts: DefaultDict[str, int] = collections.defaultdict(int) 119 for c, fn in self._data: 120 counts[map_fn(fn)] += c 121 122 return self._from_dict(counts, self.inclusive) 123 124 def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts": 125 """Keep only the elements where `filter_fn` applied to function name returns True.""" 126 return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive) 127 128 def sum(self) -> int: 129 return sum(c for c, _ in self) 130 131 def denoise(self) -> "FunctionCounts": 132 """Remove known noisy instructions. 133 134 Several instructions in the CPython interpreter are rather noisy. These 135 instructions involve unicode to dictionary lookups which Python uses to 136 map variable names. FunctionCounts is generally a content agnostic 137 container, however this is sufficiently important for obtaining 138 reliable results to warrant an exception.""" 139 return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn) 140 141 def _merge( 142 self, 143 second: "FunctionCounts", 144 merge_fn: Callable[[int], int] 145 ) -> "FunctionCounts": 146 assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts." 147 counts: DefaultDict[str, int] = collections.defaultdict(int) 148 for c, fn in self: 149 counts[fn] += c 150 151 for c, fn in second: 152 counts[fn] += merge_fn(c) 153 154 return self._from_dict(counts, self.inclusive) 155 156 @staticmethod 157 def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts": 158 flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c) 159 return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive) 160 161 162@dataclasses.dataclass(repr=False, eq=False, frozen=True) 163class CallgrindStats: 164 """Top level container for Callgrind results collected by Timer. 165 166 Manipulation is generally done using the FunctionCounts class, which is 167 obtained by calling `CallgrindStats.stats(...)`. Several convenience 168 methods are provided as well; the most significant is 169 `CallgrindStats.as_standardized()`. 170 """ 171 task_spec: common.TaskSpec 172 number_per_run: int 173 built_with_debug_symbols: bool 174 baseline_inclusive_stats: FunctionCounts 175 baseline_exclusive_stats: FunctionCounts 176 stmt_inclusive_stats: FunctionCounts 177 stmt_exclusive_stats: FunctionCounts 178 stmt_callgrind_out: Optional[str] 179 180 def __repr__(self) -> str: 181 newline = "\n" # `\` cannot appear in fstring code section. 182 base_stats = self.baseline_exclusive_stats 183 output = f""" 184{super().__repr__()} 185{self.task_spec.summarize()} 186 {'':>25}All{'':>10}Noisy symbols removed 187 Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12} 188 Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12} 189{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''} 190""".strip() 191 if not self.built_with_debug_symbols: 192 output += textwrap.dedent(""" 193 Warning: PyTorch was not built with debug symbols. 194 Source information may be limited. Rebuild with 195 REL_WITH_DEB_INFO=1 for more detailed results.""") 196 return output 197 198 def stats(self, inclusive: bool = False) -> FunctionCounts: 199 """Returns detailed function counts. 200 201 Conceptually, the FunctionCounts returned can be thought of as a tuple 202 of (count, path_and_function_name) tuples. 203 204 `inclusive` matches the semantics of callgrind. If True, the counts 205 include instructions executed by children. `inclusive=True` is useful 206 for identifying hot spots in code; `inclusive=False` is useful for 207 reducing noise when diffing counts from two different runs. (See 208 CallgrindStats.delta(...) for more details) 209 """ 210 return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats 211 212 def counts(self, *, denoise: bool = False) -> int: 213 """Returns the total number of instructions executed. 214 215 See `FunctionCounts.denoise()` for an explanation of the `denoise` arg. 216 """ 217 stats = self.stmt_exclusive_stats 218 return (stats.denoise() if denoise else stats).sum() 219 220 # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563 221 def delta( 222 self, 223 other: "CallgrindStats", 224 inclusive: bool = False, 225 ) -> FunctionCounts: 226 """Diff two sets of counts. 227 228 One common reason to collect instruction counts is to determine the 229 the effect that a particular change will have on the number of instructions 230 needed to perform some unit of work. If a change increases that number, the 231 next logical question is "why". This generally involves looking at what part 232 if the code increased in instruction count. This function automates that 233 process so that one can easily diff counts on both an inclusive and 234 exclusive basis. 235 """ 236 return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive) 237 238 def as_standardized(self) -> "CallgrindStats": 239 """Strip library names and some prefixes from function strings. 240 241 When comparing two different sets of instruction counts, on stumbling 242 block can be path prefixes. Callgrind includes the full filepath 243 when reporting a function (as it should). However, this can cause 244 issues when diffing profiles. If a key component such as Python 245 or PyTorch was built in separate locations in the two profiles, which 246 can result in something resembling:: 247 248 23234231 /tmp/first_build_dir/thing.c:foo(...) 249 9823794 /tmp/first_build_dir/thing.c:bar(...) 250 ... 251 53453 .../aten/src/Aten/...:function_that_actually_changed(...) 252 ... 253 -9823794 /tmp/second_build_dir/thing.c:bar(...) 254 -23234231 /tmp/second_build_dir/thing.c:foo(...) 255 256 Stripping prefixes can ameliorate this issue by regularizing the 257 strings and causing better cancellation of equivalent call sites 258 when diffing. 259 """ 260 def strip(stats: FunctionCounts) -> FunctionCounts: 261 transforms = ( 262 # PyTorch may have been built in different locations. 263 (r"^.+build/\.\./", "build/../"), 264 (r"^.+/" + re.escape("build/aten/"), "build/aten/"), 265 266 # "Python" and "Objects" come from CPython. 267 (r"^.+/" + re.escape("Python/"), "Python/"), 268 (r"^.+/" + re.escape("Objects/"), "Objects/"), 269 270 # Strip library name. e.g. `libtorch.so` 271 (r"\s\[.+\]$", ""), 272 ) 273 274 for before, after in transforms: 275 stats = stats.transform(lambda fn: re.sub(before, after, fn)) 276 277 return stats 278 279 return CallgrindStats( 280 task_spec=self.task_spec, 281 number_per_run=self.number_per_run, 282 built_with_debug_symbols=self.built_with_debug_symbols, 283 baseline_inclusive_stats=strip(self.baseline_inclusive_stats), 284 baseline_exclusive_stats=strip(self.baseline_exclusive_stats), 285 stmt_inclusive_stats=strip(self.stmt_inclusive_stats), 286 stmt_exclusive_stats=strip(self.stmt_exclusive_stats), 287 288 # `as_standardized` will change symbol names, so the contents will 289 # no longer map directly to `callgrind.out` 290 stmt_callgrind_out=None, 291 ) 292 293 294class Serialization(enum.Enum): 295 PICKLE = 0 296 TORCH = 1 297 TORCH_JIT = 2 298 299 300_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = { 301 Serialization.PICKLE: (str, bytes, bool, int, float, complex), 302 Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule), 303 Serialization.TORCH: (torch.nn.Module,), 304} 305 306 307class CopyIfCallgrind: 308 """Signal that a global may be replaced with a deserialized copy. 309 310 See `GlobalsBridge` for why this matters. 311 """ 312 def __init__(self, value: Any, *, setup: Optional[str] = None): 313 for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): 314 if any(isinstance(value, t) for t in supported_types): 315 self._value: Any = value 316 self._setup: Optional[str] = setup 317 self._serialization: Serialization = method 318 break 319 else: 320 supported_str = "\n".join([ 321 getattr(t, "__name__", repr(t)) 322 for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())]) 323 324 raise ValueError( 325 f"Unsupported type: {type(value)}\n" 326 f"`collect_callgrind` restricts globals to the following types:\n" 327 f"{textwrap.indent(supported_str, ' ')}" 328 ) 329 330 @property 331 def value(self) -> Any: 332 return self._value 333 334 @property 335 def setup(self) -> Optional[str]: 336 return self._setup 337 338 @property 339 def serialization(self) -> Serialization: 340 return self._serialization 341 342 @staticmethod 343 def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]: 344 return { 345 k: (v.value if isinstance(v, CopyIfCallgrind) else v) 346 for k, v in globals.items() 347 } 348 349 350class GlobalsBridge: 351 """Handle the transfer of (certain) globals when collecting Callgrind statistics. 352 353 Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to 354 work with `Timer.collect_callgrind`. 355 356 Consider the following code snippet: 357 ``` 358 import pickle 359 import timeit 360 361 class Counter: 362 value = 0 363 364 def __call__(self): 365 self.value += 1 366 367 counter = Counter() 368 timeit.Timer("counter()", globals={"counter": counter}).timeit(10) 369 print(counter.value) # 10 370 371 timeit.Timer( 372 "counter()", 373 globals={"counter": pickle.loads(pickle.dumps(counter))} 374 ).timeit(20) 375 print(counter.value) # Still 10 376 ``` 377 378 In the first case, `stmt` is executed using the objects in `globals`; 379 however, the addition of serialization and deserialization changes the 380 semantics and may meaningfully change behavior. 381 382 This is a practical consideration when collecting Callgrind statistics. 383 Unlike `exec` based execution (which `timeit` uses under the hood) which 384 can share in-memory data structures with the caller, Callgrind collection 385 requires an entirely new process in order to run under Valgrind. This means 386 that any data structures used for statement execution will have to be 387 serialized and deserialized in the subprocess. 388 389 In order to avoid surprising semantics from (user invisible) process 390 boundaries, what can be passed through `globals` is severely restricted 391 for `Timer.collect_callgrind`. It is expected that most setup should be 392 achievable (albeit perhaps less ergonomically) by passing a `setup` 393 string. 394 395 There are, however, exceptions. One such class are TorchScripted functions. 396 Because they require a concrete file with source code it is not possible 397 to define them using a `setup` string. Another group are torch.nn.Modules, 398 whose construction can be complex and prohibitively cumbersome to coerce 399 into a `setup` string. Finally, most builtin types are sufficiently well 400 behaved and sufficiently common to warrant allowing as well. (e.g. 401 `globals={"n": 1}` is very convenient.) 402 403 Fortunately, all have well defined serialization semantics. This class 404 is responsible for enabling the Valgrind subprocess to use elements in 405 `globals` so long as they are an allowed type. 406 407 Caveats: 408 The user is required to acknowledge this serialization by wrapping 409 elements in `globals` with `CopyIfCallgrind`. 410 411 While ScriptFunction and ScriptModule are expected to save and load 412 quite robustly, it is up to the user to ensure that an nn.Module can 413 un-pickle successfully. 414 415 `torch.Tensor` and `np.ndarray` are deliberately excluded. The 416 serialization/deserialization process perturbs the representation of a 417 tensor in ways that could result in incorrect measurements. For example, 418 if a tensor lives in pinned CPU memory, this fact would not be preserved 419 by a dump, and that will in turn change the performance of certain CUDA 420 operations. 421 """ 422 423 def __init__(self, globals: Dict[str, Any], data_dir: str) -> None: 424 self._globals: Dict[str, CopyIfCallgrind] = {} 425 self._data_dir = data_dir 426 if not os.path.exists(data_dir): 427 os.mkdir(data_dir) 428 429 if globals.get("torch", torch) is not torch: 430 raise ValueError("`collect_callgrind` does not support mocking out `torch`.") 431 432 for name, value in globals.items(): 433 if name in ("torch", "__builtins__"): 434 # Torch will be imported by the collection script, and 435 # __builtins__ is added by Timer. 436 continue 437 438 if not isinstance(value, CopyIfCallgrind): 439 raise ValueError( 440 "`collect_callgrind` requires that globals be wrapped in " 441 "`CopyIfCallgrind` so that serialization is explicit." 442 ) 443 444 self._globals[name] = value 445 446 def construct(self) -> str: 447 load_lines = [] 448 for name, wrapped_value in self._globals.items(): 449 if wrapped_value.setup is not None: 450 load_lines.append(textwrap.dedent(wrapped_value.setup)) 451 452 if wrapped_value.serialization == Serialization.PICKLE: 453 path = os.path.join(self._data_dir, f"{name}.pkl") 454 load_lines.append( 455 f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") 456 with open(path, "wb") as f: 457 pickle.dump(wrapped_value.value, f) 458 459 elif wrapped_value.serialization == Serialization.TORCH: 460 path = os.path.join(self._data_dir, f"{name}.pt") 461 load_lines.append(f"{name} = torch.load({repr(path)})") 462 torch.save(wrapped_value.value, path) 463 464 elif wrapped_value.serialization == Serialization.TORCH_JIT: 465 path = os.path.join(self._data_dir, f"{name}.pt") 466 load_lines.append(f"{name} = torch.jit.load({repr(path)})") 467 with open(path, "wb") as f: 468 torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call] 469 470 else: 471 raise NotImplementedError( 472 f"Unknown serialization method: {wrapped_value.serialization}") 473 474 return "\n".join(load_lines) 475 476 477class _ValgrindWrapper: 478 def __init__(self) -> None: 479 self._bindings_module: Optional[CallgrindModuleType] = None 480 valgrind_symbols = ( 481 "_valgrind_supported_platform", 482 "_valgrind_toggle", 483 "_valgrind_toggle_and_dump_stats", 484 ) 485 if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols): 486 self._supported_platform: bool = torch._C._valgrind_supported_platform() 487 488 else: 489 print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.") 490 self._bindings_module = cpp_jit.get_compat_bindings() 491 assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols) 492 self._supported_platform = self._bindings_module._valgrind_supported_platform() 493 494 self._commands_available: Dict[str, bool] = {} 495 if self._supported_platform: 496 # Only bother checking on supported platforms. 497 for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"): 498 self._commands_available[cmd] = not subprocess.run( 499 ["which", cmd], 500 capture_output=True, 501 check=False, 502 ).returncode 503 504 self._build_type: Optional[str] = None 505 build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) # type: ignore[no-untyped-call] 506 if build_search is not None: 507 self._build_type = build_search.groups()[0].split(",")[0] 508 509 def _validate(self) -> None: 510 if not self._supported_platform: 511 raise OSError("Valgrind is not supported on this platform.") 512 513 missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available] 514 if missing_cmds: 515 raise OSError("Missing: " + ", ".join(missing_cmds)) 516 517 def collect_callgrind( 518 self, 519 task_spec: common.TaskSpec, 520 globals: Dict[str, Any], 521 *, 522 number: int, 523 repeats: int, 524 collect_baseline: bool, 525 is_python: bool, 526 retain_out_file: bool, 527 ) -> Tuple[CallgrindStats, ...]: 528 """Collect stats, and attach a reference run which can be used to filter interpreter overhead.""" 529 self._validate() 530 assert is_python or not collect_baseline 531 532 *task_stats, baseline_stats = self._invoke( 533 task_spec=task_spec, 534 globals=globals, 535 number=number, 536 repeats=repeats, 537 collect_baseline=collect_baseline, 538 is_python=is_python, 539 retain_out_file=retain_out_file, 540 ) 541 assert len(task_stats) == repeats 542 543 return tuple( 544 CallgrindStats( 545 task_spec=task_spec, 546 number_per_run=number, 547 built_with_debug_symbols=self._build_type == "RelWithDebInfo", 548 baseline_inclusive_stats=baseline_stats[0], 549 baseline_exclusive_stats=baseline_stats[1], 550 stmt_inclusive_stats=stmt_inclusive_stats, 551 stmt_exclusive_stats=stmt_exclusive_stats, 552 stmt_callgrind_out=out_contents, 553 ) 554 for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats 555 ) 556 557 def _invoke( 558 self, 559 *, 560 task_spec: common.TaskSpec, 561 globals: Dict[str, Any], 562 number: int, 563 repeats: int, 564 collect_baseline: bool, 565 is_python: bool, 566 retain_out_file: bool, 567 ) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]: 568 """Core invocation method for Callgrind collection. 569 570 Valgrind operates by effectively replacing the CPU with an emulated 571 version which allows it to instrument any code at the cost of severe 572 performance degradation. This has the practical effect that in order 573 to collect Callgrind statistics, a new process has to be created 574 running under `valgrind`. The steps for this process are: 575 576 1) Create a scratch directory. 577 2) Codegen a run script. (_ValgrindWrapper._construct_script) 578 Inside the run script: 579 * Validate that Python and torch match the parent process 580 * Validate that it is indeed running under valgrind 581 * Execute `setup` and warm up `stmt` 582 * Begin collecting stats 583 * Run the `stmt` loop 584 * Stop collecting stats 585 3) Parse the run results. 586 4) Cleanup the scratch directory. 587 """ 588 working_dir = common._make_temp_dir(prefix="callgrind") 589 data_dir = os.path.join(working_dir, "data") 590 script_file = os.path.join(working_dir, "timer_callgrind.py") 591 callgrind_out = os.path.join(working_dir, "callgrind.out") 592 error_log = os.path.join(working_dir, "error.txt") 593 stat_log = os.path.join(working_dir, "callgrind_stat.txt") 594 stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log") 595 596 def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]: 597 # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/ 598 f_stdout_stderr = open(stdout_stderr_log, "wb") 599 try: 600 invocation = subprocess.run( 601 args, 602 stdout=f_stdout_stderr, 603 stderr=subprocess.STDOUT, 604 **kwargs, 605 ) 606 with open(stdout_stderr_log) as f: 607 return invocation, f.read() 608 finally: 609 f_stdout_stderr.close() 610 611 try: 612 if is_python: 613 if self._bindings_module is not None: 614 shutil.copy( 615 self._bindings_module.__file__, 616 os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1]) 617 ) 618 619 script_file = os.path.join(working_dir, "timer_callgrind.py") 620 with open(script_file, "w") as f: 621 f.write(self._construct_script( 622 task_spec, 623 globals=GlobalsBridge(globals, data_dir), 624 number=number, 625 repeats=repeats, 626 collect_baseline=collect_baseline, 627 error_log=error_log, 628 stat_log=stat_log, 629 bindings=self._bindings_module)) 630 631 run_loop_cmd = ["python", script_file] 632 else: 633 assert not collect_baseline 634 run_loop_exec = cpp_jit.compile_callgrind_template( 635 stmt=task_spec.stmt, 636 setup=task_spec.setup, 637 global_setup=task_spec.global_setup, 638 ) 639 run_loop_cmd = [ 640 run_loop_exec, 641 "--number", str(number), 642 "--number-warmup", str(min(number, 10)), 643 "--repeats", str(repeats), 644 "--number-threads", str(task_spec.num_threads), 645 ] 646 647 valgrind_invocation, valgrind_invocation_output = run([ 648 "valgrind", 649 "--tool=callgrind", 650 f"--callgrind-out-file={callgrind_out}", 651 "--dump-line=yes", 652 "--dump-instr=yes", 653 "--instr-atstart=yes", 654 "--collect-atstart=no", 655 ] + run_loop_cmd) 656 657 if valgrind_invocation.returncode: 658 error_report = "" 659 if os.path.exists(error_log): 660 with open(error_log) as f: 661 error_report = f.read() 662 if not error_report: 663 error_report = "Unknown error.\n" + valgrind_invocation_output 664 665 raise OSError(f"Failed to collect callgrind profile:\n{error_report}") 666 667 def parse_output(fpath: str, inclusive: bool) -> FunctionCounts: 668 annotate_invocation, annotate_invocation_output = run([ 669 "callgrind_annotate", 670 f"--inclusive={'yes' if inclusive else 'no'}", 671 "--threshold=100", 672 "--show-percs=no", 673 fpath 674 ], check=True) 675 676 total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS") 677 begin_pattern = re.compile(r"Ir\s+file:function") 678 function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$") 679 680 class ScanState(enum.Enum): 681 SCANNING_FOR_TOTAL = 0 682 SCANNING_FOR_START = 1 683 PARSING = 2 684 685 scan_state = ScanState.SCANNING_FOR_TOTAL 686 fn_counts = [] 687 for l in annotate_invocation_output.splitlines(keepends=False): 688 if scan_state == ScanState.SCANNING_FOR_TOTAL: 689 total_match = total_pattern.match(l) 690 if total_match: 691 program_totals = int(total_match.groups()[0].replace(",", "")) 692 scan_state = ScanState.SCANNING_FOR_START 693 694 elif scan_state == ScanState.SCANNING_FOR_START: 695 if begin_pattern.match(l): 696 scan_state = ScanState.PARSING 697 698 else: 699 assert scan_state == ScanState.PARSING 700 fn_match = function_pattern.match(l) 701 if fn_match: 702 ir_str, file_function = fn_match.groups() 703 ir = int(ir_str.replace(",", "")) 704 if ir == program_totals: # type: ignore[possibly-undefined] 705 # Callgrind includes some top level red herring symbols when 706 # a program dumps multiple profiles. 707 continue 708 fn_counts.append(FunctionCount(ir, file_function)) 709 710 elif re.match(r"-+", l): 711 # Ignore heading separator lines. 712 continue 713 714 else: 715 break 716 717 assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}" 718 return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive) 719 720 def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]: 721 if i == repeats and not collect_baseline: 722 # Null baseline. 723 return ( 724 FunctionCounts((), inclusive=True), 725 FunctionCounts((), inclusive=False), 726 None, 727 ) 728 729 fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files. 730 callgrind_out_contents: Optional[str] = None 731 if retain_out_file: 732 with open(fpath) as f: 733 callgrind_out_contents = f.read() 734 735 return ( 736 parse_output(fpath, inclusive=True), 737 parse_output(fpath, inclusive=False), 738 callgrind_out_contents 739 ) 740 741 return tuple(read_results(i) for i in range(repeats + 1)) 742 finally: 743 shutil.rmtree(working_dir) 744 745 @staticmethod 746 def _construct_script( 747 task_spec: common.TaskSpec, 748 globals: GlobalsBridge, 749 *, 750 number: int, 751 repeats: int, 752 collect_baseline: bool, 753 error_log: str, 754 stat_log: str, 755 bindings: Optional[CallgrindModuleType], 756 ) -> str: 757 def block_stmt(stmt: str, indent: int = 0) -> str: 758 """Partially unroll benchmark loop. 759 760 The naive template looks something like: 761 "for _ in range({number}): {stmt}" 762 763 However a loop in Python is surprisingly expensive, and significantly 764 increases the number of background Python instructions. So instead we 765 partially unroll the loops, with a block size of 100 chosen to keep 766 the instruction overhead from `range` low while also not ballooning 767 the size of the generated file. 768 """ 769 block_size = 100 770 loop_count = number // block_size 771 if loop_count == 1: 772 # There is no point in having `for _ in range(1): ...` rather 773 # than just `...`, and this lets us save shave a few background 774 # instructions. 775 loop_count = 0 776 remainder = number - block_size * loop_count 777 blocked_stmt = "" 778 779 if loop_count: 780 unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4) 781 blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n" 782 783 if remainder: 784 blocked_stmt += "\n".join([stmt] * remainder) 785 786 return textwrap.indent(blocked_stmt, " " * indent) 787 788 pass_baseline = ( 789 "callgrind_bindings._valgrind_toggle()\n" 790 f"{block_stmt('pass')}\n" 791 "callgrind_bindings._valgrind_toggle_and_dump_stats()" 792 ) 793 794 return textwrap.dedent(r""" 795 import gc 796 import os 797 import pickle 798 import subprocess 799 import sys 800 import time 801 802 # Mitigate https://github.com/pytorch/pytorch/issues/37377 803 # which can sometimes cause the subprocess call to fail. 804 import numpy as np 805 806 import torch 807 torch.set_num_threads({num_threads}) 808 809 {bindings_import} 810 811 PID = os.getpid() 812 813 def log_failure(msg): 814 with open({error_log_repr}, "wt") as f: 815 f.write(msg) 816 sys.exit(1) 817 818 def check_result(completed_process): 819 if completed_process.returncode: 820 log_failure(f"Command failed: {{' '.join(completed_process.args)}}") 821 return completed_process 822 823 # ============================================================================= 824 # == Check that subprocess matches parent ===================================== 825 # ============================================================================= 826 if os.path.realpath(sys.executable) != "{parent_interpreter}": 827 log_failure( 828 "Interpreter mismatch:\n" 829 f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}" 830 ) 831 832 if torch.__file__ != "{torch_file}": 833 log_failure( 834 "PyTorch does not match expected file:\n" 835 f" {{torch.__file__}}\n vs.\n {torch_file}" 836 ) 837 838 # ============================================================================= 839 # == User specified setup ===================================================== 840 # ============================================================================= 841 # Load serialized globals 842 {load_globals} 843 844 # User setup str 845 {setup} 846 847 for _ in range({warmup_number}): 848 {indented_stmt} 849 850 # ============================================================================= 851 # == Callgrind management ===================================================== 852 # ============================================================================= 853 with open("{stat_log}", "wb") as stat_file: 854 # If many instances of callgrind are running at once, the output of 855 # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE` 856 # to deadlock. So instead we use a file. 857 callgrind_stat = check_result(subprocess.run( 858 ["callgrind_control", "--stat"], 859 stdout=stat_file, 860 stderr=subprocess.STDOUT, 861 )) 862 863 with open("{stat_log}", "rt") as stat_file: 864 stat_lines = stat_file.read().splitlines() 865 866 if f"PID {{PID}}: python {{__file__}}" not in stat_lines: 867 log_failure("Process does not appear to be running callgrind.") 868 869 gc.collect() 870 time.sleep(0.01) 871 872 # ============================================================================= 873 # == User code block ========================================================== 874 # ============================================================================= 875 for _ in range({repeats}): 876 callgrind_bindings._valgrind_toggle() 877 {blocked_stmt} 878 callgrind_bindings._valgrind_toggle_and_dump_stats() 879 gc.collect() 880 881 {baseline} 882 """).strip().format( 883 indented_stmt=textwrap.indent(task_spec.stmt, " " * 4), 884 blocked_stmt=block_stmt(task_spec.stmt, indent=4), 885 baseline=(pass_baseline if collect_baseline else ""), 886 number=number, 887 repeats=repeats, 888 load_globals=globals.construct(), 889 setup=task_spec.setup, 890 warmup_number=min(number, 10), 891 num_threads=task_spec.num_threads, 892 error_log_repr=repr(error_log), 893 stat_log=stat_log, 894 parent_interpreter=os.path.realpath(sys.executable), 895 torch_file=torch.__file__, 896 bindings_import=( 897 "import torch._C as callgrind_bindings" if bindings is None 898 else f"import {bindings.__name__} as callgrind_bindings"), 899 ) 900 901 902CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None 903def wrapper_singleton() -> _ValgrindWrapper: 904 global CALLGRIND_SINGLETON 905 if CALLGRIND_SINGLETON is None: 906 CALLGRIND_SINGLETON = _ValgrindWrapper() 907 return CALLGRIND_SINGLETON 908