1"""Intermediate layer between `Timer` and `valgrind`."""
2import collections
3import enum
4import dataclasses
5import itertools as it
6import os
7import pickle
8import re
9import shutil
10import subprocess
11import sys
12import textwrap
13from typing import (
14    cast, Any, Callable, DefaultDict, Dict, Iterator, List, NamedTuple,
15    Optional, Tuple, Union, TYPE_CHECKING)
16
17import torch
18from torch.utils.benchmark.utils import common, cpp_jit
19from torch.utils.benchmark.utils._stubs import CallgrindModuleType
20import operator
21
22
23__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
24
25
26if TYPE_CHECKING:
27    CompletedProcessType = subprocess.CompletedProcess[str]
28else:
29    CompletedProcessType = subprocess.CompletedProcess
30
31
32class FunctionCount(NamedTuple):
33    # TODO(#105471): Rename the count field
34    count: int  # type: ignore[assignment]
35    function: str
36
37
38@dataclasses.dataclass(repr=False, eq=False, frozen=True)
39class FunctionCounts:
40    """Container for manipulating Callgrind results.
41
42    It supports:
43        1) Addition and subtraction to combine or diff results.
44        2) Tuple-like indexing.
45        3) A `denoise` function which strips CPython calls which are known to
46           be non-deterministic and quite noisy.
47        4) Two higher order methods (`filter` and `transform`) for custom
48           manipulation.
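
    A minimal usage sketch (the symbol names below are illustrative
    placeholders, not real profile output)::

        counts = FunctionCounts(
            (FunctionCount(100, "foo.c:foo"), FunctionCount(25, "bar.c:bar")),
            inclusive=False,
        )
        counts.sum()                                   # 125
        counts.filter(lambda fn: "foo" in fn)          # Keeps only foo.c:foo
        counts.transform(lambda fn: fn.split(":")[0])  # Coalesces counts by file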
49    """
50    _data: Tuple[FunctionCount, ...]
51    inclusive: bool
52    truncate_rows: bool = True
53
54    # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
55    # the print settings. This is simply to allow hermetic unit tests.
56    _linewidth: Optional[int] = None
57
58    def __iter__(self) -> Iterator[FunctionCount]:
59        yield from self._data
60
61    def __len__(self) -> int:
62        return len(self._data)
63
64    def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
65        data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item]
66        return (
67            FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
68            if isinstance(data, tuple) else data
69        )
70
71    def __repr__(self) -> str:
72        count_len = 0
73        for c, _ in self:
74            # Account for sign in string length.
75            count_len = max(count_len, len(str(c)) + int(c < 0))
76
77        lines = []
78        linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
79        fn_str_len = max(linewidth - count_len - 4, 40)
80        for c, fn in self:
81            if len(fn) > fn_str_len:
82                left_len = int((fn_str_len - 5) // 2)
83                fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
84            lines.append(f"  {c:>{count_len}}  {fn}")
85
86        if self.truncate_rows and len(lines) > 18:
87            lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]
88
89        if not self.inclusive:
90            lines.extend(["", f"Total: {self.sum()}"])
91
92        return "\n".join([super().__repr__()] + lines)
93
94    def __add__(
95        self,
96        other: "FunctionCounts",
97    ) -> "FunctionCounts":
98        return self._merge(other, lambda c: c)
99
100    def __sub__(
101        self,
102        other: "FunctionCounts",
103    ) -> "FunctionCounts":
104        return self._merge(other, operator.neg)
105
106    def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
107        return self._from_dict({
108            fn: int(c * other) for c, fn in self._data
109        }, self.inclusive)
110
111    def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
112        """Apply `map_fn` to all of the function names.
113
114        This can be used to regularize function names (e.g. stripping irrelevant
115        parts of the file path), coalesce entries by mapping multiple functions
116        to the same name (in which case the counts are added together), etc.
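
        For example (an illustrative sketch, mirroring the regexes used by
        `CallgrindStats.as_standardized`), stripping build prefixes so that
        equivalent functions coalesce::

            regularized = counts.transform(
                lambda fn: re.sub(r"^.+/build/", "build/", fn))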
117        """
118        counts: DefaultDict[str, int] = collections.defaultdict(int)
119        for c, fn in self._data:
120            counts[map_fn(fn)] += c
121
122        return self._from_dict(counts, self.inclusive)
123
124    def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
125        """Keep only the elements where `filter_fn` applied to function name returns True."""
126        return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)
127
128    def sum(self) -> int:
129        return sum(c for c, _ in self)

    def denoise(self) -> "FunctionCounts":
        """Remove known noisy instructions.

        Several instructions in the CPython interpreter are rather noisy. These
        instructions involve dictionary lookups on unicode keys, which Python
        uses to map variable names. `FunctionCounts` is generally a content
        agnostic container; however, this is sufficiently important for
        obtaining reliable results to warrant an exception."""
        return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)

    def _merge(
        self,
        second: "FunctionCounts",
        merge_fn: Callable[[int], int]
    ) -> "FunctionCounts":
        assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
        counts: DefaultDict[str, int] = collections.defaultdict(int)
        for c, fn in self:
            counts[fn] += c

        for c, fn in second:
            counts[fn] += merge_fn(c)

        return self._from_dict(counts, self.inclusive)

    @staticmethod
    def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts":
        flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
        return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)


@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class CallgrindStats:
    """Top level container for Callgrind results collected by Timer.

    Manipulation is generally done using the FunctionCounts class, which is
    obtained by calling `CallgrindStats.stats(...)`. Several convenience
    methods are provided as well; the most significant is
    `CallgrindStats.as_standardized()`.
    """
    task_spec: common.TaskSpec
    number_per_run: int
    built_with_debug_symbols: bool
    baseline_inclusive_stats: FunctionCounts
    baseline_exclusive_stats: FunctionCounts
    stmt_inclusive_stats: FunctionCounts
    stmt_exclusive_stats: FunctionCounts
    stmt_callgrind_out: Optional[str]

    def __repr__(self) -> str:
        newline = "\n"  # `\` cannot appear in fstring code section.
        base_stats = self.baseline_exclusive_stats
        output = f"""
{super().__repr__()}
{self.task_spec.summarize()}
  {'':>25}All{'':>10}Noisy symbols removed
    Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
    Baseline:     {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
""".strip()
        if not self.built_with_debug_symbols:
            output += textwrap.dedent("""
            Warning: PyTorch was not built with debug symbols.
                     Source information may be limited. Rebuild with
                     REL_WITH_DEB_INFO=1 for more detailed results.""")
        return output

    def stats(self, inclusive: bool = False) -> FunctionCounts:
        """Returns detailed function counts.

        Conceptually, the FunctionCounts returned can be thought of as a tuple
        of (count, path_and_function_name) tuples.

        `inclusive` matches the semantics of callgrind. If True, the counts
        include instructions executed by children. `inclusive=True` is useful
        for identifying hot spots in code; `inclusive=False` is useful for
        reducing noise when diffing counts from two different runs. (See
        CallgrindStats.delta(...) for more details)
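
        A hedged usage sketch (`cg_stats` is assumed to be a `CallgrindStats`
        obtained from `Timer.collect_callgrind`)::

            hot_spots = cg_stats.stats(inclusive=True)
            exclusive_counts = cg_stats.stats(inclusive=False)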
209        """
210        return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats
211
212    def counts(self, *, denoise: bool = False) -> int:
213        """Returns the total number of instructions executed.
214
215        See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
216        """
217        stats = self.stmt_exclusive_stats
218        return (stats.denoise() if denoise else stats).sum()
219
220    # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
221    def delta(
222        self,
223        other: "CallgrindStats",
224        inclusive: bool = False,
225    ) -> FunctionCounts:
226        """Diff two sets of counts.
227
228        One common reason to collect instruction counts is to determine the
229        the effect that a particular change will have on the number of instructions
230        needed to perform some unit of work. If a change increases that number, the
231        next logical question is "why". This generally involves looking at what part
232        if the code increased in instruction count. This function automates that
233        process so that one can easily diff counts on both an inclusive and
234        exclusive basis.
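
        A hedged sketch (`before` and `after` are assumed to be
        `CallgrindStats` from runs without and with a change)::

            regression = after.delta(before)  # Positive counts are added work.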
235        """
236        return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)
237
238    def as_standardized(self) -> "CallgrindStats":
239        """Strip library names and some prefixes from function strings.
240
241        When comparing two different sets of instruction counts, on stumbling
242        block can be path prefixes. Callgrind includes the full filepath
243        when reporting a function (as it should). However, this can cause
244        issues when diffing profiles. If a key component such as Python
245        or PyTorch was built in separate locations in the two profiles, which
246        can result in something resembling::

            23234231 /tmp/first_build_dir/thing.c:foo(...)
             9823794 /tmp/first_build_dir/thing.c:bar(...)
              ...
               53453 .../aten/src/Aten/...:function_that_actually_changed(...)
              ...
             -9823794 /tmp/second_build_dir/thing.c:bar(...)
            -23234231 /tmp/second_build_dir/thing.c:foo(...)

        Stripping prefixes can ameliorate this issue by regularizing the
        strings and causing better cancellation of equivalent call sites
        when diffing.
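
        A hedged usage sketch: standardize both sides before diffing so that
        equivalent call sites cancel::

            delta = after.as_standardized().delta(before.as_standardized())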
259        """
260        def strip(stats: FunctionCounts) -> FunctionCounts:
261            transforms = (
262                # PyTorch may have been built in different locations.
263                (r"^.+build/\.\./", "build/../"),
264                (r"^.+/" + re.escape("build/aten/"), "build/aten/"),
265
266                # "Python" and "Objects" come from CPython.
267                (r"^.+/" + re.escape("Python/"), "Python/"),
268                (r"^.+/" + re.escape("Objects/"), "Objects/"),
269
270                # Strip library name. e.g. `libtorch.so`
271                (r"\s\[.+\]$", ""),
272            )
273
274            for before, after in transforms:
275                stats = stats.transform(lambda fn: re.sub(before, after, fn))
276
277            return stats
278
279        return CallgrindStats(
280            task_spec=self.task_spec,
281            number_per_run=self.number_per_run,
282            built_with_debug_symbols=self.built_with_debug_symbols,
283            baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
284            baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
285            stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
286            stmt_exclusive_stats=strip(self.stmt_exclusive_stats),
287
288            # `as_standardized` will change symbol names, so the contents will
289            # no longer map directly to `callgrind.out`
290            stmt_callgrind_out=None,
291        )
292
293
294class Serialization(enum.Enum):
295    PICKLE = 0
296    TORCH = 1
297    TORCH_JIT = 2
298
299
300_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = {
301    Serialization.PICKLE: (str, bytes, bool, int, float, complex),
302    Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
303    Serialization.TORCH: (torch.nn.Module,),
304}
305
306
307class CopyIfCallgrind:
308    """Signal that a global may be replaced with a deserialized copy.
309
310    See `GlobalsBridge` for why this matters.
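
    A minimal usage sketch (the statement, setup, and model are illustrative;
    `torch.nn.Linear` stands in for any allowed `torch.nn.Module`)::

        from torch.utils.benchmark import Timer

        timer = Timer(
            "y = model(x)",
            setup="x = torch.ones((4, 4))",
            globals={"model": CopyIfCallgrind(torch.nn.Linear(4, 4))},
        )
        stats = timer.collect_callgrind()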
311    """
312    def __init__(self, value: Any, *, setup: Optional[str] = None):
313        for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
314            if any(isinstance(value, t) for t in supported_types):
315                self._value: Any = value
316                self._setup: Optional[str] = setup
317                self._serialization: Serialization = method
318                break
319        else:
320            supported_str = "\n".join([
321                getattr(t, "__name__", repr(t))
322                for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())])
323
324            raise ValueError(
325                f"Unsupported type: {type(value)}\n"
326                f"`collect_callgrind` restricts globals to the following types:\n"
327                f"{textwrap.indent(supported_str, '  ')}"
328            )

    @property
    def value(self) -> Any:
        return self._value

    @property
    def setup(self) -> Optional[str]:
        return self._setup

    @property
    def serialization(self) -> Serialization:
        return self._serialization

    @staticmethod
    def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]:
        return {
            k: (v.value if isinstance(v, CopyIfCallgrind) else v)
            for k, v in globals.items()
        }


class GlobalsBridge:
    """Handle the transfer of (certain) globals when collecting Callgrind statistics.

    Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
                  work with `Timer.collect_callgrind`.

    Consider the following code snippet:
    ```
        import pickle
        import timeit

        class Counter:
            value = 0

            def __call__(self):
                self.value += 1

        counter = Counter()
        timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
        print(counter.value)  # 10

        timeit.Timer(
            "counter()",
            globals={"counter": pickle.loads(pickle.dumps(counter))}
        ).timeit(20)
        print(counter.value)  # Still 10
    ```

    In the first case, `stmt` is executed using the objects in `globals`;
    however, the addition of serialization and deserialization changes the
    semantics and may meaningfully change behavior.

    This is a practical consideration when collecting Callgrind statistics.
    Unlike `exec` based execution (which `timeit` uses under the hood), which
    can share in-memory data structures with the caller, Callgrind collection
    requires an entirely new process in order to run under Valgrind. This means
    that any data structures used for statement execution will have to be
    serialized and deserialized in the subprocess.

    In order to avoid surprising semantics from (user invisible) process
    boundaries, what can be passed through `globals` is severely restricted
    for `Timer.collect_callgrind`. It is expected that most setup should be
    achievable (albeit perhaps less ergonomically) by passing a `setup`
    string.

    There are, however, exceptions. One such class is TorchScripted functions.
    Because they require a concrete file with source code, it is not possible
    to define them using a `setup` string. Another group is torch.nn.Modules,
    whose construction can be complex and prohibitively cumbersome to coerce
    into a `setup` string. Finally, most builtin types are sufficiently well
    behaved and sufficiently common to warrant allowing as well. (e.g.
    `globals={"n": 1}` is very convenient.)

    Fortunately, all have well defined serialization semantics. This class
    is responsible for enabling the Valgrind subprocess to use elements in
    `globals` so long as they are an allowed type.

    Caveats:
        The user is required to acknowledge this serialization by wrapping
        elements in `globals` with `CopyIfCallgrind`.

        While ScriptFunction and ScriptModule are expected to save and load
        quite robustly, it is up to the user to ensure that an nn.Module can
        un-pickle successfully.

        `torch.Tensor` and `np.ndarray` are deliberately excluded. The
        serialization/deserialization process perturbs the representation of a
        tensor in ways that could result in incorrect measurements. For example,
        if a tensor lives in pinned CPU memory, this fact would not be preserved
        by a dump, and that will in turn change the performance of certain CUDA
        operations.
    """

    def __init__(self, globals: Dict[str, Any], data_dir: str) -> None:
        self._globals: Dict[str, CopyIfCallgrind] = {}
        self._data_dir = data_dir
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        if globals.get("torch", torch) is not torch:
            raise ValueError("`collect_callgrind` does not support mocking out `torch`.")

        for name, value in globals.items():
            if name in ("torch", "__builtins__"):
                # Torch will be imported by the collection script, and
                # __builtins__ is added by Timer.
                continue

            if not isinstance(value, CopyIfCallgrind):
                raise ValueError(
                    "`collect_callgrind` requires that globals be wrapped in "
                    "`CopyIfCallgrind` so that serialization is explicit."
                )

            self._globals[name] = value

    def construct(self) -> str:
        load_lines = []
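        # The generated source is a series of load statements. For a pickled
        # global named "n" it would resemble (path shown is illustrative):
        #   with open('/tmp/.../data/n.pkl', 'rb') as f:
        #       n = pickle.load(f)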
        for name, wrapped_value in self._globals.items():
            if wrapped_value.setup is not None:
                load_lines.append(textwrap.dedent(wrapped_value.setup))

            if wrapped_value.serialization == Serialization.PICKLE:
                path = os.path.join(self._data_dir, f"{name}.pkl")
                load_lines.append(
                    f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
                with open(path, "wb") as f:
                    pickle.dump(wrapped_value.value, f)

            elif wrapped_value.serialization == Serialization.TORCH:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.load({repr(path)})")
                torch.save(wrapped_value.value, path)

            elif wrapped_value.serialization == Serialization.TORCH_JIT:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
                with open(path, "wb") as f:
                    torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]

            else:
                raise NotImplementedError(
                    f"Unknown serialization method: {wrapped_value.serialization}")

        return "\n".join(load_lines)


class _ValgrindWrapper:
    def __init__(self) -> None:
        self._bindings_module: Optional[CallgrindModuleType] = None
        valgrind_symbols = (
            "_valgrind_supported_platform",
            "_valgrind_toggle",
            "_valgrind_toggle_and_dump_stats",
        )
        if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
            self._supported_platform: bool = torch._C._valgrind_supported_platform()

        else:
            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
            self._bindings_module = cpp_jit.get_compat_bindings()
            assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
            self._supported_platform = self._bindings_module._valgrind_supported_platform()

        self._commands_available: Dict[str, bool] = {}
        if self._supported_platform:
            # Only bother checking on supported platforms.
            for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
                self._commands_available[cmd] = not subprocess.run(
                    ["which", cmd],
                    capture_output=True,
                    check=False,
                ).returncode

        self._build_type: Optional[str] = None
        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())  # type: ignore[no-untyped-call]
        if build_search is not None:
            self._build_type = build_search.groups()[0].split(",")[0]

    def _validate(self) -> None:
        if not self._supported_platform:
            raise OSError("Valgrind is not supported on this platform.")

        missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
        if missing_cmds:
            raise OSError("Missing: " + ", ".join(missing_cmds))

    def collect_callgrind(
        self,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[CallgrindStats, ...]:
        """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
        self._validate()
        assert is_python or not collect_baseline

        *task_stats, baseline_stats = self._invoke(
            task_spec=task_spec,
            globals=globals,
            number=number,
            repeats=repeats,
            collect_baseline=collect_baseline,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )
        assert len(task_stats) == repeats

        return tuple(
            CallgrindStats(
                task_spec=task_spec,
                number_per_run=number,
                built_with_debug_symbols=self._build_type == "RelWithDebInfo",
                baseline_inclusive_stats=baseline_stats[0],
                baseline_exclusive_stats=baseline_stats[1],
                stmt_inclusive_stats=stmt_inclusive_stats,
                stmt_exclusive_stats=stmt_exclusive_stats,
                stmt_callgrind_out=out_contents,
            )
            for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
        )

    def _invoke(
        self,
        *,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
        """Core invocation method for Callgrind collection.

        Valgrind operates by effectively replacing the CPU with an emulated
        version which allows it to instrument any code at the cost of severe
        performance degradation. This has the practical effect that in order
        to collect Callgrind statistics, a new process has to be created
        running under `valgrind`. The steps for this process are:

        1) Create a scratch directory.
        2) Codegen a run script. (_ValgrindWrapper._construct_script)
            Inside the run script:
                * Validate that Python and torch match the parent process
                * Validate that it is indeed running under valgrind
                * Execute `setup` and warm up `stmt`
                * Begin collecting stats
                * Run the `stmt` loop
                * Stop collecting stats
        3) Parse the run results.
        4) Cleanup the scratch directory.
        """
        working_dir = common._make_temp_dir(prefix="callgrind")
        data_dir = os.path.join(working_dir, "data")
        script_file = os.path.join(working_dir, "timer_callgrind.py")
        callgrind_out = os.path.join(working_dir, "callgrind.out")
        error_log = os.path.join(working_dir, "error.txt")
        stat_log = os.path.join(working_dir, "callgrind_stat.txt")
        stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")

        def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]:
            # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
            f_stdout_stderr = open(stdout_stderr_log, "wb")
            try:
                invocation = subprocess.run(
                    args,
                    stdout=f_stdout_stderr,
                    stderr=subprocess.STDOUT,
                    **kwargs,
                )
                with open(stdout_stderr_log) as f:
                    return invocation, f.read()
            finally:
                f_stdout_stderr.close()

        try:
            if is_python:
                if self._bindings_module is not None:
                    shutil.copy(
                        self._bindings_module.__file__,
                        os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
                    )

                script_file = os.path.join(working_dir, "timer_callgrind.py")
                with open(script_file, "w") as f:
                    f.write(self._construct_script(
                        task_spec,
                        globals=GlobalsBridge(globals, data_dir),
                        number=number,
                        repeats=repeats,
                        collect_baseline=collect_baseline,
                        error_log=error_log,
                        stat_log=stat_log,
                        bindings=self._bindings_module))

                run_loop_cmd = ["python", script_file]
            else:
                assert not collect_baseline
                run_loop_exec = cpp_jit.compile_callgrind_template(
                    stmt=task_spec.stmt,
                    setup=task_spec.setup,
                    global_setup=task_spec.global_setup,
                )
                run_loop_cmd = [
                    run_loop_exec,
                    "--number", str(number),
                    "--number-warmup", str(min(number, 10)),
                    "--repeats", str(repeats),
                    "--number-threads", str(task_spec.num_threads),
                ]

            valgrind_invocation, valgrind_invocation_output = run([
                "valgrind",
                "--tool=callgrind",
                f"--callgrind-out-file={callgrind_out}",
                "--dump-line=yes",
                "--dump-instr=yes",
                "--instr-atstart=yes",
                "--collect-atstart=no",
            ] + run_loop_cmd)

            if valgrind_invocation.returncode:
                error_report = ""
                if os.path.exists(error_log):
                    with open(error_log) as f:
                        error_report = f.read()
                if not error_report:
                    error_report = "Unknown error.\n" + valgrind_invocation_output

                raise OSError(f"Failed to collect callgrind profile:\n{error_report}")

            def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
                annotate_invocation, annotate_invocation_output = run([
                    "callgrind_annotate",
                    f"--inclusive={'yes' if inclusive else 'no'}",
                    "--threshold=100",
                    "--show-percs=no",
                    fpath
                ], check=True)

                total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
                begin_pattern = re.compile(r"Ir\s+file:function")
                function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")
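                # Illustrative shapes of the `callgrind_annotate` lines the
                # patterns above are meant to match (values are invented):
                #   "11,223,344  PROGRAM TOTALS"
                #   "   123,456  Objects/dictobject.c:lookdict_unicode [libpython3.so]"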

                class ScanState(enum.Enum):
                    SCANNING_FOR_TOTAL = 0
                    SCANNING_FOR_START = 1
                    PARSING = 2

                scan_state = ScanState.SCANNING_FOR_TOTAL
                fn_counts = []
                for l in annotate_invocation_output.splitlines(keepends=False):
                    if scan_state == ScanState.SCANNING_FOR_TOTAL:
                        total_match = total_pattern.match(l)
                        if total_match:
                            program_totals = int(total_match.groups()[0].replace(",", ""))
                            scan_state = ScanState.SCANNING_FOR_START

                    elif scan_state == ScanState.SCANNING_FOR_START:
                        if begin_pattern.match(l):
                            scan_state = ScanState.PARSING

                    else:
                        assert scan_state == ScanState.PARSING
                        fn_match = function_pattern.match(l)
                        if fn_match:
                            ir_str, file_function = fn_match.groups()
                            ir = int(ir_str.replace(",", ""))
                            if ir == program_totals:  # type: ignore[possibly-undefined]
                                # Callgrind includes some top level red herring symbols when
                                # a program dumps multiple profiles.
                                continue
                            fn_counts.append(FunctionCount(ir, file_function))

                        elif re.match(r"-+", l):
                            # Ignore heading separator lines.
                            continue

                        else:
                            break

                assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
                return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)

            def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]:
                if i == repeats and not collect_baseline:
                    # Null baseline.
                    return (
                        FunctionCounts((), inclusive=True),
                        FunctionCounts((), inclusive=False),
                        None,
                    )

                fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
                callgrind_out_contents: Optional[str] = None
                if retain_out_file:
                    with open(fpath) as f:
                        callgrind_out_contents = f.read()

                return (
                    parse_output(fpath, inclusive=True),
                    parse_output(fpath, inclusive=False),
                    callgrind_out_contents
                )

            return tuple(read_results(i) for i in range(repeats + 1))
        finally:
            shutil.rmtree(working_dir)

    @staticmethod
    def _construct_script(
        task_spec: common.TaskSpec,
        globals: GlobalsBridge,
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        error_log: str,
        stat_log: str,
        bindings: Optional[CallgrindModuleType],
    ) -> str:
        def block_stmt(stmt: str, indent: int = 0) -> str:
            """Partially unroll benchmark loop.

            The naive template looks something like:
                "for _ in range({number}): {stmt}"

            However, a loop in Python is surprisingly expensive, and it
            significantly increases the number of background Python
            instructions. So instead we partially unroll the loops, with a
            block size of 100 chosen to keep the instruction overhead from
            `range` low while also not ballooning the size of the generated
            file.
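
            For example (illustrative), `number=250` gives `loop_count=2` and
            `remainder=50`: a `for _ in range(2):` loop over 100 unrolled
            copies of `stmt`, followed by 50 more copies of `stmt`.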
768            """
769            block_size = 100
770            loop_count = number // block_size
771            if loop_count == 1:
772                # There is no point in having `for _ in range(1): ...` rather
773                # than just `...`, and this lets us save shave a few background
774                # instructions.
775                loop_count = 0
776            remainder = number - block_size * loop_count
777            blocked_stmt = ""

            if loop_count:
                unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
                blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"

            if remainder:
                blocked_stmt += "\n".join([stmt] * remainder)

            return textwrap.indent(blocked_stmt, " " * indent)

        pass_baseline = (
            "callgrind_bindings._valgrind_toggle()\n"
            f"{block_stmt('pass')}\n"
            "callgrind_bindings._valgrind_toggle_and_dump_stats()"
        )

        return textwrap.dedent(r"""
            import gc
            import os
            import pickle
            import subprocess
            import sys
            import time

            # Mitigate https://github.com/pytorch/pytorch/issues/37377
            # which can sometimes cause the subprocess call to fail.
            import numpy as np

            import torch
            torch.set_num_threads({num_threads})

            {bindings_import}

            PID = os.getpid()

            def log_failure(msg):
                with open({error_log_repr}, "wt") as f:
                    f.write(msg)
                sys.exit(1)

            def check_result(completed_process):
                if completed_process.returncode:
                    log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
                return completed_process

            # =============================================================================
            # == Check that subprocess matches parent =====================================
            # =============================================================================
            if os.path.realpath(sys.executable) != "{parent_interpreter}":
                log_failure(
                    "Interpreter mismatch:\n"
                    f"  {{os.path.realpath(sys.executable)}}\n    vs.\n  {parent_interpreter}"
                )

            if torch.__file__ != "{torch_file}":
                log_failure(
                    "PyTorch does not match expected file:\n"
                    f"  {{torch.__file__}}\n    vs.\n  {torch_file}"
                )

            # =============================================================================
            # == User specified setup =====================================================
            # =============================================================================
            # Load serialized globals
            {load_globals}

            # User setup str
            {setup}

            for _ in range({warmup_number}):
            {indented_stmt}

            # =============================================================================
            # == Callgrind management =====================================================
            # =============================================================================
            with open("{stat_log}", "wb") as stat_file:
                # If many instances of callgrind are running at once, the output of
                # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
                # to deadlock. So instead we use a file.
                callgrind_stat = check_result(subprocess.run(
                    ["callgrind_control", "--stat"],
                    stdout=stat_file,
                    stderr=subprocess.STDOUT,
                ))

            with open("{stat_log}", "rt") as stat_file:
                stat_lines = stat_file.read().splitlines()

            if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
                log_failure("Process does not appear to be running callgrind.")

            gc.collect()
            time.sleep(0.01)

            # =============================================================================
            # == User code block ==========================================================
            # =============================================================================
            for _ in range({repeats}):
                callgrind_bindings._valgrind_toggle()
            {blocked_stmt}
                callgrind_bindings._valgrind_toggle_and_dump_stats()
                gc.collect()

            {baseline}
        """).strip().format(
            indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
            blocked_stmt=block_stmt(task_spec.stmt, indent=4),
            baseline=(pass_baseline if collect_baseline else ""),
            number=number,
            repeats=repeats,
            load_globals=globals.construct(),
            setup=task_spec.setup,
            warmup_number=min(number, 10),
            num_threads=task_spec.num_threads,
            error_log_repr=repr(error_log),
            stat_log=stat_log,
            parent_interpreter=os.path.realpath(sys.executable),
            torch_file=torch.__file__,
            bindings_import=(
                "import torch._C as callgrind_bindings" if bindings is None
                else f"import {bindings.__name__} as callgrind_bindings"),
        )


CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None
def wrapper_singleton() -> _ValgrindWrapper:
    global CALLGRIND_SINGLETON
    if CALLGRIND_SINGLETON is None:
        CALLGRIND_SINGLETON = _ValgrindWrapper()
    return CALLGRIND_SINGLETON
908