xref: /aosp_15_r20/external/pytorch/torch/_dynamo/profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3import os
4from typing import Any, List
5
6import torch
7
8from .utils import print_once
9
10
@dataclasses.dataclass
class ProfileMetrics:
    """Bundle of counters describing a set of profiled operators.

    Addition (``+`` / ``+=``) combines every field of two instances.
    Division produces per-field ratios for ``microseconds``, ``operators``
    and ``fusions`` only; ``graphs`` is left at its default in the quotient.
    """

    # Accumulated elapsed time, in microseconds.
    microseconds: float = 0.0
    # Number of operators counted.
    operators: int = 0
    # Fusion count as computed by the code that builds the metrics.
    fusions: int = 0
    # Number of graph calls counted.
    graphs: int = 0

    def __iadd__(self, other: "ProfileMetrics"):
        """Add every counter of ``other`` into this instance and return it."""
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        # Accumulate ``graphs`` as well; otherwise an aggregated result keeps
        # reporting only the left operand's graph-call count.
        self.graphs += other.graphs
        return self

    def __add__(self, other: "ProfileMetrics"):
        """Return a new instance holding the field-wise sum of both operands."""
        assert isinstance(other, ProfileMetrics)
        return ProfileMetrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
            self.fusions + other.fusions,
            self.graphs + other.graphs,
        )

    def __truediv__(self, other):
        """Return field-wise ratios of ``self`` to ``other``.

        ``other`` may be another ``ProfileMetrics`` or an ``int``; an ``int``
        is used as the divisor for every field. Each divisor is clamped to
        at least 1, so a zero (or smaller) denominator never raises.
        """
        if isinstance(other, int):
            other = ProfileMetrics(other, other, other)
        return ProfileMetrics(
            self.microseconds / max(1, other.microseconds),
            self.operators / max(1, other.operators),
            self.fusions / max(1, other.fusions),
        )

    def __str__(self) -> str:
        """Render ``operators`` and ``microseconds`` as percentages."""
        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"

    def tocsv(self):
        """Return ``[operators, microseconds]`` as a CSV row fragment."""
        return [self.operators, self.microseconds]
46
47
class ProfileResult:
    """Pairs the captured metrics with the total metrics of a profiling run.

    ``captured`` and ``total`` fall back to a fresh ``ProfileMetrics()`` when
    a falsy value is passed; ``unique_graphs`` is stored as given.
    """

    def __init__(self, captured, total, unique_graphs) -> None:
        self.captured: ProfileMetrics = captured if captured else ProfileMetrics()
        self.total: ProfileMetrics = total if total else ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult"):
        """Fold ``other`` into this result in place and return ``self``."""
        self.unique_graphs += other.unique_graphs
        self.total += other.total
        self.captured += other.captured
        return self

    def percent(self):
        """Return ``captured / total``."""
        ratio = self.captured / self.total
        return ratio

    def __str__(self) -> str:
        """One-line summary: graph counts, operator counts and the ratio."""
        graph_part = (
            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
        )
        ops_part = f"{self.captured.operators:4}/{self.total.operators:4} = "
        return f"{graph_part}{ops_part}{self.percent()}"

    def tocsv(self):
        """Return the summary as a flat list suitable for a CSV row."""
        row = [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ]
        row.extend(self.percent().tocsv())
        return row
77
78
def should_print_missing():
    """Return True only when ``TORCHDYNAMO_PRINT_MISSING`` is set to ``"1"``."""
    flag = os.getenv("TORCHDYNAMO_PRINT_MISSING")
    return flag == "1"
81
82
def print_missing(stack):
    """Report, via ``print_once``, the last three relevant frames of ``stack``.

    Nothing is printed when any entry mentions ``/torch/autograd/profiler.py``.
    Entries containing ``<built-in`` or ``site-packages/torch/`` are dropped
    before the trailing three are joined with ``" >> "``.
    """
    for frame in stack:
        if "/torch/autograd/profiler.py" in frame:
            return

    kept = []
    for frame in stack:
        if "<built-in" in frame or "site-packages/torch/" in frame:
            continue
        kept.append(frame)

    tail = kept[-3:]
    print_once("MISSING", " >> ".join(tail))
90
91
class Profiler:
    """Holds a CPU ``torch.profiler.profile`` session and summarises its events.

    ``unique_graphs`` is a class-level counter; ``results()`` reads it and
    resets it to zero.
    """

    unique_graphs = 0

    def __init__(self) -> None:
        # Stacks are only recorded when missing ops are going to be printed.
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_stack=should_print_missing(),
        )

    def results(self):
        """Summarise the recorded events as a ``ProfileResult``.

        Events are walked in order of start time. Each ``"TORCHDYNAMO"`` event
        opens a captured region; any other event that starts no earlier than
        the end of the previously counted op is counted as an op, and as a
        captured op when it ends within the current region. Everything else
        is treated as nested inside another op and skipped.
        """
        region_count = 0
        ops_in_regions = 0
        us_in_regions = 0
        ops_overall = 0
        us_overall = 0

        prev_op_end = -1
        region_end = -1
        ordered = sorted(self.prof.events(), key=lambda ev: ev.time_range.start)
        for event in ordered:
            if event.name == "TORCHDYNAMO":
                region_end = event.time_range.end
                region_count += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                ops_overall -= 1
                continue

            if event.time_range.start < prev_op_end:
                # ops recursively called from other ops (ignored)
                continue

            prev_op_end = event.time_range.end
            if event.time_range.end <= region_end:
                ops_in_regions += 1
                us_in_regions += event.time_range.elapsed_us()
            elif should_print_missing():
                print_missing(event.stack)
            ops_overall += 1
            us_overall += event.time_range.elapsed_us()

        # Hand the graph counter to this result and start the next run at zero.
        graphs_seen = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        ops_overall -= 1

        captured_metrics = ProfileMetrics(
            microseconds=us_in_regions,
            operators=ops_in_regions,
            fusions=ops_in_regions - region_count,
            graphs=region_count,
        )
        total_metrics = ProfileMetrics(
            microseconds=us_overall,
            operators=ops_overall,
            fusions=ops_overall - 1,
        )
        return ProfileResult(
            captured=captured_metrics,
            total=total_metrics,
            unique_graphs=graphs_seen,
        )
148
149
def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
    """Return a callable that runs ``gm.forward`` inside a "TORCHDYNAMO" region.

    Each call to this function also increments ``Profiler.unique_graphs``.
    ``example_inputs`` is accepted but not used.
    """
    Profiler.unique_graphs += 1

    def _wrapped(*args):
        # The label must stay "TORCHDYNAMO": Profiler.results() matches on it.
        region = torch.profiler.record_function("TORCHDYNAMO")
        with region:
            return gm.forward(*args)

    return _wrapped
157