xref: /aosp_15_r20/external/pytorch/torch/autograd/profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
4*da0073e9SAndroid Build Coastguard Workerfrom time import perf_counter_ns
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Iterable, List, Optional
6*da0073e9SAndroid Build Coastguard Workerfrom warnings import warn
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerimport torch.cuda
10*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _get_privateuse1_backend_name
11*da0073e9SAndroid Build Coastguard Workerfrom torch._C._profiler import _ExperimentalConfig
12*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import (
13*da0073e9SAndroid Build Coastguard Worker    _disable_profiler,
14*da0073e9SAndroid Build Coastguard Worker    _enable_profiler,
15*da0073e9SAndroid Build Coastguard Worker    _kineto_step,
16*da0073e9SAndroid Build Coastguard Worker    _prepare_profiler,
17*da0073e9SAndroid Build Coastguard Worker    _ProfilerResult,
18*da0073e9SAndroid Build Coastguard Worker    _supported_activities,
19*da0073e9SAndroid Build Coastguard Worker    _toggle_collection_dynamic,
20*da0073e9SAndroid Build Coastguard Worker    DeviceType,
21*da0073e9SAndroid Build Coastguard Worker    kineto_available,
22*da0073e9SAndroid Build Coastguard Worker    ProfilerActivity,
23*da0073e9SAndroid Build Coastguard Worker    ProfilerConfig,
24*da0073e9SAndroid Build Coastguard Worker    ProfilerState,
25*da0073e9SAndroid Build Coastguard Worker)
26*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler_util import (
27*da0073e9SAndroid Build Coastguard Worker    _filter_name,
28*da0073e9SAndroid Build Coastguard Worker    _filter_stack_entry,
29*da0073e9SAndroid Build Coastguard Worker    _rewrite_name,
30*da0073e9SAndroid Build Coastguard Worker    EventList,
31*da0073e9SAndroid Build Coastguard Worker    FunctionEvent,
32*da0073e9SAndroid Build Coastguard Worker    MEMORY_EVENT_NAME,
33*da0073e9SAndroid Build Coastguard Worker    MemRecordsAcc,
34*da0073e9SAndroid Build Coastguard Worker    OUT_OF_MEMORY_EVENT_NAME,
35*da0073e9SAndroid Build Coastguard Worker)
36*da0073e9SAndroid Build Coastguard Workerfrom torch.futures import Future
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker__all__ = [
40*da0073e9SAndroid Build Coastguard Worker    "profile",
41*da0073e9SAndroid Build Coastguard Worker    "record_function",
42*da0073e9SAndroid Build Coastguard Worker    "emit_itt",
43*da0073e9SAndroid Build Coastguard Worker    "emit_nvtx",
44*da0073e9SAndroid Build Coastguard Worker    "load_nvprof",
45*da0073e9SAndroid Build Coastguard Worker    "EnforceUnique",
46*da0073e9SAndroid Build Coastguard Worker    "parse_nvprof_trace",
47*da0073e9SAndroid Build Coastguard Worker    "KinetoStepTracker",
48*da0073e9SAndroid Build Coastguard Worker    "EventList",
49*da0073e9SAndroid Build Coastguard Worker    "FunctionEvent",
50*da0073e9SAndroid Build Coastguard Worker    "MemRecordsAcc",
51*da0073e9SAndroid Build Coastguard Worker]
52*da0073e9SAndroid Build Coastguard Worker
try:
    # contextlib.ContextDecorator ships with Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        """Fallback base class that lets a context manager double as a decorator.

        Subclasses are expected to override ``__enter__`` and ``__exit__``;
        the versions defined here only raise ``NotImplementedError``.
        """

        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            # Run every call of ``func`` inside ``with self:``.
            @functools.wraps(func)
            def _inside_context(*call_args, **call_kwargs):
                with self:
                    return func(*call_args, **call_kwargs)

            return _inside_context
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
# Module-level Python flag mirroring whether the profiler is currently enabled.
# Kept as a plain bool so Python code can check it cheaply (low latency).
_is_profiler_enabled: bool = False
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
def _set_is_profiler_enabled(enable: bool) -> None:
    """Store ``enable`` in the module-level ``_is_profiler_enabled`` flag."""
    global _is_profiler_enabled
    _is_profiler_enabled = enable
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
def _run_on_profiler_start() -> None:
    """Hook run when profiling starts: marks the profiler as enabled."""
    _set_is_profiler_enabled(enable=True)
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
def _run_on_profiler_stop() -> None:
    """Hook run when profiling stops: marks the profiler as disabled."""
    _set_is_profiler_enabled(enable=False)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
@dataclass
class _ProfilerStats:
    """Profiler timing and stats used by developers to catch issues/regressions.

    Every field defaults to zero. Field names carry their unit suffix
    (``_sec`` for seconds, ``_us`` for microseconds).
    """

    # Length of the profiling window.
    profiling_window_duration_sec: float = 0
    # Number of function events collected.
    number_of_events: int = 0
    # Time spent in the profiler "prepare" call.
    profiler_prepare_call_duration_us: int = 0
    # Time spent in the profiler "enable" call.
    profiler_enable_call_duration_us: int = 0
    # Time spent in the profiler "disable" call.
    profiler_disable_call_duration_us: int = 0
    # Time spent parsing the Kineto results.
    parse_kineto_call_duration_us: int = 0
    # Time spent building the function-event tree.
    function_events_build_tree_call_duration_us: int = 0
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Workerclass profile:
106*da0073e9SAndroid Build Coastguard Worker    """Context manager that manages autograd profiler state and holds a summary of results.
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker    Under the hood it just records events of functions being executed in C++ and
109*da0073e9SAndroid Build Coastguard Worker    exposes those events to Python. You can wrap any code into it and it will
110*da0073e9SAndroid Build Coastguard Worker    only report runtime of PyTorch functions.
111*da0073e9SAndroid Build Coastguard Worker    Note: profiler is thread local and is automatically propagated into the async tasks
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker    Args:
114*da0073e9SAndroid Build Coastguard Worker        enabled (bool, optional): Setting this to False makes this context manager a no-op.
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        use_cuda (bool, optional): Enables timing of CUDA events as well
117*da0073e9SAndroid Build Coastguard Worker            using the cudaEvent API. (will be deprecated)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker        use_device (str, optional): Enables timing of device events.
120*da0073e9SAndroid Build Coastguard Worker            Adds approximately 4us of overhead to each tensor operation when use cuda.
121*da0073e9SAndroid Build Coastguard Worker            The valid devices options are 'cuda', 'xpu', 'mtia' and 'privateuseone'.
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker        record_shapes (bool, optional): If shapes recording is set, information
124*da0073e9SAndroid Build Coastguard Worker            about input dimensions will be collected. This allows one to see which
125*da0073e9SAndroid Build Coastguard Worker            dimensions have been used under the hood and further group by them
126*da0073e9SAndroid Build Coastguard Worker            using prof.key_averages(group_by_input_shape=True). Please note that
127*da0073e9SAndroid Build Coastguard Worker            shape recording might skew your profiling data. It is recommended to
128*da0073e9SAndroid Build Coastguard Worker            use separate runs with and without shape recording to validate the timing.
129*da0073e9SAndroid Build Coastguard Worker            Most likely the skew will be negligible for bottom most events (in a case
130*da0073e9SAndroid Build Coastguard Worker            of nested function calls). But for higher level functions the total
131*da0073e9SAndroid Build Coastguard Worker            self cpu time might be artificially increased because of the shape
132*da0073e9SAndroid Build Coastguard Worker            collection.
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker        with_flops (bool, optional): If with_flops is set, the profiler will estimate
135*da0073e9SAndroid Build Coastguard Worker            the FLOPs (floating point operations) value using the operator's input shape.
136*da0073e9SAndroid Build Coastguard Worker            This allows one to estimate the hardware performance. Currently,
137*da0073e9SAndroid Build Coastguard Worker            this option only works for the matrix multiplication and 2D convolution operators.
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker        profile_memory (bool, optional): track tensor memory allocation/deallocation.
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker        with_stack (bool, optional): record source information (file and line number) for the ops.
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        with_modules (bool): record module hierarchy (including function names)
144*da0073e9SAndroid Build Coastguard Worker            corresponding to the callstack of the op. e.g. If module A's forward call's
145*da0073e9SAndroid Build Coastguard Worker            module B's forward which contains an aten::add op,
146*da0073e9SAndroid Build Coastguard Worker            then aten::add's module hierarchy is A.B
147*da0073e9SAndroid Build Coastguard Worker            Note that this support exist, at the moment, only for TorchScript models
148*da0073e9SAndroid Build Coastguard Worker            and not eager mode models.
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
153*da0073e9SAndroid Build Coastguard Worker            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        experimental_config (_ExperimentalConfig) : A set of experimental options
156*da0073e9SAndroid Build Coastguard Worker            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker
    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead.
164*da0073e9SAndroid Build Coastguard Worker
    .. warning::
        This context manager should not be used recursively, i.e. nested
        instances are not allowed.
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker    .. warning:
170*da0073e9SAndroid Build Coastguard Worker        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
171*da0073e9SAndroid Build Coastguard Worker        one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
172*da0073e9SAndroid Build Coastguard Worker        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
173*da0073e9SAndroid Build Coastguard Worker        please use ``use_device = None`` or ``num_workers = 0``.
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker    Example:
176*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +SKIP
177*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
178*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn((1, 1), requires_grad=True)
179*da0073e9SAndroid Build Coastguard Worker        >>> with torch.autograd.profiler.profile() as prof:
180*da0073e9SAndroid Build Coastguard Worker        >>>     for _ in range(100):  # any normal python code, really!
181*da0073e9SAndroid Build Coastguard Worker        >>>         y = x ** 2
182*da0073e9SAndroid Build Coastguard Worker        >>>         y.backward()
183*da0073e9SAndroid Build Coastguard Worker        >>> # NOTE: some columns were removed for brevity
184*da0073e9SAndroid Build Coastguard Worker        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
185*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
186*da0073e9SAndroid Build Coastguard Worker        Name                                 Self CPU total   CPU time avg     Number of Calls
187*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
188*da0073e9SAndroid Build Coastguard Worker        mul                                  32.048ms         32.048ms         200
189*da0073e9SAndroid Build Coastguard Worker        pow                                  27.041ms         27.041ms         200
190*da0073e9SAndroid Build Coastguard Worker        PowBackward0                         9.727ms          55.483ms         100
191*da0073e9SAndroid Build Coastguard Worker        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
192*da0073e9SAndroid Build Coastguard Worker        torch::autograd::GraphRoot           691.816us        691.816us        100
193*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    """
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker    def __init__(
198*da0073e9SAndroid Build Coastguard Worker        self,
199*da0073e9SAndroid Build Coastguard Worker        enabled=True,
200*da0073e9SAndroid Build Coastguard Worker        *,
201*da0073e9SAndroid Build Coastguard Worker        use_cuda=False,  # Deprecated
202*da0073e9SAndroid Build Coastguard Worker        use_device=None,
203*da0073e9SAndroid Build Coastguard Worker        record_shapes=False,
204*da0073e9SAndroid Build Coastguard Worker        with_flops=False,
205*da0073e9SAndroid Build Coastguard Worker        profile_memory=False,
206*da0073e9SAndroid Build Coastguard Worker        with_stack=False,
207*da0073e9SAndroid Build Coastguard Worker        with_modules=False,
208*da0073e9SAndroid Build Coastguard Worker        use_kineto=False,
209*da0073e9SAndroid Build Coastguard Worker        use_cpu=True,
210*da0073e9SAndroid Build Coastguard Worker        experimental_config=None,
211*da0073e9SAndroid Build Coastguard Worker        acc_events=False,
212*da0073e9SAndroid Build Coastguard Worker    ):
213*da0073e9SAndroid Build Coastguard Worker        self.enabled: bool = enabled
214*da0073e9SAndroid Build Coastguard Worker        if not self.enabled:
215*da0073e9SAndroid Build Coastguard Worker            return
216*da0073e9SAndroid Build Coastguard Worker        self.use_cuda = use_cuda
217*da0073e9SAndroid Build Coastguard Worker        if self.use_cuda:
218*da0073e9SAndroid Build Coastguard Worker            warn(
219*da0073e9SAndroid Build Coastguard Worker                "The attribute `use_cuda` will be deprecated soon, "
220*da0073e9SAndroid Build Coastguard Worker                "please use ``use_device = 'cuda'`` instead.",
221*da0073e9SAndroid Build Coastguard Worker                FutureWarning,
222*da0073e9SAndroid Build Coastguard Worker                stacklevel=2,
223*da0073e9SAndroid Build Coastguard Worker            )
224*da0073e9SAndroid Build Coastguard Worker            self.use_device: Optional[str] = "cuda"
225*da0073e9SAndroid Build Coastguard Worker        else:
226*da0073e9SAndroid Build Coastguard Worker            self.use_device = use_device
227*da0073e9SAndroid Build Coastguard Worker        # TODO Consider changing _function_events into data structure with size cap
228*da0073e9SAndroid Build Coastguard Worker        self._function_events: Optional[EventList] = None
229*da0073e9SAndroid Build Coastguard Worker        self._old_function_events: Optional[EventList] = None
230*da0073e9SAndroid Build Coastguard Worker        # Function event processing is done lazily
231*da0073e9SAndroid Build Coastguard Worker        self._needs_processing = False
232*da0073e9SAndroid Build Coastguard Worker        self.entered = False
233*da0073e9SAndroid Build Coastguard Worker        self.record_shapes = record_shapes
234*da0073e9SAndroid Build Coastguard Worker        self.with_flops = with_flops
235*da0073e9SAndroid Build Coastguard Worker        self.record_shapes |= self.with_flops
236*da0073e9SAndroid Build Coastguard Worker        self.profile_memory = profile_memory
237*da0073e9SAndroid Build Coastguard Worker        self.with_stack = with_stack
238*da0073e9SAndroid Build Coastguard Worker        self.with_modules = with_modules
239*da0073e9SAndroid Build Coastguard Worker        self.use_cpu = use_cpu
240*da0073e9SAndroid Build Coastguard Worker        self.acc_events = acc_events
241*da0073e9SAndroid Build Coastguard Worker        if experimental_config is None:
242*da0073e9SAndroid Build Coastguard Worker            experimental_config = _ExperimentalConfig()
243*da0073e9SAndroid Build Coastguard Worker        self.experimental_config = experimental_config
244*da0073e9SAndroid Build Coastguard Worker        self.kineto_results: Optional[_ProfilerResult] = None
245*da0073e9SAndroid Build Coastguard Worker        self.profiling_start_time_ns = 0
246*da0073e9SAndroid Build Coastguard Worker        self.profiling_end_time_ns = 0
247*da0073e9SAndroid Build Coastguard Worker        self._stats = _ProfilerStats()
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        if not self.use_cpu:
250*da0073e9SAndroid Build Coastguard Worker            assert (
251*da0073e9SAndroid Build Coastguard Worker                use_kineto
252*da0073e9SAndroid Build Coastguard Worker            ), "Device-only events supported only with Kineto (use_kineto=True)"
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker        if self.use_device is not None:
255*da0073e9SAndroid Build Coastguard Worker            VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"]
256*da0073e9SAndroid Build Coastguard Worker            if _get_privateuse1_backend_name() != "privateuseone":
257*da0073e9SAndroid Build Coastguard Worker                VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
258*da0073e9SAndroid Build Coastguard Worker            if self.use_device not in VALID_DEVICE_OPTIONS:
259*da0073e9SAndroid Build Coastguard Worker                warn(f"The {self.use_device} is not a valid device option.")
260*da0073e9SAndroid Build Coastguard Worker                self.use_device = None
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker            if self.use_device == "cuda" and not torch.cuda.is_available():
263*da0073e9SAndroid Build Coastguard Worker                warn("CUDA is not available, disabling CUDA profiling")
264*da0073e9SAndroid Build Coastguard Worker                self.use_cuda = False
265*da0073e9SAndroid Build Coastguard Worker                self.use_device = None
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker            if self.use_device == "xpu" and not torch.xpu.is_available():
268*da0073e9SAndroid Build Coastguard Worker                warn("XPU is not available, disabling XPU profiling")
269*da0073e9SAndroid Build Coastguard Worker                self.use_device = None
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        self.kineto_activities = set()
272*da0073e9SAndroid Build Coastguard Worker        if self.use_cpu:
273*da0073e9SAndroid Build Coastguard Worker            self.kineto_activities.add(ProfilerActivity.CPU)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        self.profiler_kind = ProfilerState.KINETO
276*da0073e9SAndroid Build Coastguard Worker        if self.use_device == "cuda":
277*da0073e9SAndroid Build Coastguard Worker            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
278*da0073e9SAndroid Build Coastguard Worker                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
279*da0073e9SAndroid Build Coastguard Worker                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
280*da0073e9SAndroid Build Coastguard Worker            else:
281*da0073e9SAndroid Build Coastguard Worker                self.kineto_activities.add(ProfilerActivity.CUDA)
282*da0073e9SAndroid Build Coastguard Worker        elif self.use_device == "xpu":
283*da0073e9SAndroid Build Coastguard Worker            assert (
284*da0073e9SAndroid Build Coastguard Worker                use_kineto and ProfilerActivity.XPU in _supported_activities()
285*da0073e9SAndroid Build Coastguard Worker            ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
286*da0073e9SAndroid Build Coastguard Worker            self.kineto_activities.add(ProfilerActivity.XPU)
287*da0073e9SAndroid Build Coastguard Worker        elif self.use_device == "mtia":
288*da0073e9SAndroid Build Coastguard Worker            assert (
289*da0073e9SAndroid Build Coastguard Worker                use_kineto and ProfilerActivity.MTIA in _supported_activities()
290*da0073e9SAndroid Build Coastguard Worker            ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
291*da0073e9SAndroid Build Coastguard Worker            self.kineto_activities.add(ProfilerActivity.MTIA)
292*da0073e9SAndroid Build Coastguard Worker        elif self.use_device is not None and self.use_device != "privateuseone":
293*da0073e9SAndroid Build Coastguard Worker            if (
294*da0073e9SAndroid Build Coastguard Worker                not use_kineto
295*da0073e9SAndroid Build Coastguard Worker                or ProfilerActivity.PrivateUse1 not in _supported_activities()
296*da0073e9SAndroid Build Coastguard Worker            ):
297*da0073e9SAndroid Build Coastguard Worker                assert (
298*da0073e9SAndroid Build Coastguard Worker                    self.use_cpu
299*da0073e9SAndroid Build Coastguard Worker                ), "Legacy custombackend profiling requires use_cpu=True"
300*da0073e9SAndroid Build Coastguard Worker                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
301*da0073e9SAndroid Build Coastguard Worker            else:
302*da0073e9SAndroid Build Coastguard Worker                self.kineto_activities.add(ProfilerActivity.PrivateUse1)
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        assert (
305*da0073e9SAndroid Build Coastguard Worker            len(self.kineto_activities) > 0
306*da0073e9SAndroid Build Coastguard Worker        ), "No activities specified for the profiler"
307*da0073e9SAndroid Build Coastguard Worker
308*da0073e9SAndroid Build Coastguard Worker    def config(self):
309*da0073e9SAndroid Build Coastguard Worker        return ProfilerConfig(
310*da0073e9SAndroid Build Coastguard Worker            self.profiler_kind,
311*da0073e9SAndroid Build Coastguard Worker            self.record_shapes,
312*da0073e9SAndroid Build Coastguard Worker            self.profile_memory,
313*da0073e9SAndroid Build Coastguard Worker            self.with_stack,
314*da0073e9SAndroid Build Coastguard Worker            self.with_flops,
315*da0073e9SAndroid Build Coastguard Worker            self.with_modules,
316*da0073e9SAndroid Build Coastguard Worker            self.experimental_config,
317*da0073e9SAndroid Build Coastguard Worker        )
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker    def __enter__(self):
320*da0073e9SAndroid Build Coastguard Worker        if not self.enabled:
321*da0073e9SAndroid Build Coastguard Worker            return
322*da0073e9SAndroid Build Coastguard Worker        if self.entered:
323*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Profiler context manager is not reentrant")
324*da0073e9SAndroid Build Coastguard Worker        self._prepare_trace()
325*da0073e9SAndroid Build Coastguard Worker        self._start_trace()
326*da0073e9SAndroid Build Coastguard Worker        return self
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker    def _prepare_trace(self):
329*da0073e9SAndroid Build Coastguard Worker        self.entered = True
330*da0073e9SAndroid Build Coastguard Worker        t0 = perf_counter_ns()
331*da0073e9SAndroid Build Coastguard Worker        _prepare_profiler(self.config(), self.kineto_activities)
332*da0073e9SAndroid Build Coastguard Worker        t1 = perf_counter_ns()
333*da0073e9SAndroid Build Coastguard Worker        self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker    def _start_trace(self):
336*da0073e9SAndroid Build Coastguard Worker        self.entered = True
337*da0073e9SAndroid Build Coastguard Worker        _run_on_profiler_start()
338*da0073e9SAndroid Build Coastguard Worker        t0 = perf_counter_ns()
339*da0073e9SAndroid Build Coastguard Worker        _enable_profiler(self.config(), self.kineto_activities)
340*da0073e9SAndroid Build Coastguard Worker        t1 = perf_counter_ns()
341*da0073e9SAndroid Build Coastguard Worker        self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000)
342*da0073e9SAndroid Build Coastguard Worker        self.profiling_start_time_ns = t1
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker    def __exit__(self, exc_type, exc_val, exc_tb):
345*da0073e9SAndroid Build Coastguard Worker        if not self.enabled:
346*da0073e9SAndroid Build Coastguard Worker            return
347*da0073e9SAndroid Build Coastguard Worker        if self.use_device and hasattr(torch, self.use_device):
348*da0073e9SAndroid Build Coastguard Worker            device_module = getattr(torch, self.use_device)
349*da0073e9SAndroid Build Coastguard Worker            if hasattr(device_module, "synchronize"):
350*da0073e9SAndroid Build Coastguard Worker                device_module.synchronize()
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker        if self._function_events and self.acc_events:
353*da0073e9SAndroid Build Coastguard Worker            self._old_function_events = self._function_events
354*da0073e9SAndroid Build Coastguard Worker        self._function_events = None
355*da0073e9SAndroid Build Coastguard Worker        self._needs_processing = True
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        t0 = perf_counter_ns()
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker        self.kineto_results = _disable_profiler()
360*da0073e9SAndroid Build Coastguard Worker        t1 = perf_counter_ns()
361*da0073e9SAndroid Build Coastguard Worker        self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000)
362*da0073e9SAndroid Build Coastguard Worker        self.profiling_end_time_ns = t0
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        _run_on_profiler_stop()
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        self._stats.profiling_window_duration_sec = (
367*da0073e9SAndroid Build Coastguard Worker            (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9
368*da0073e9SAndroid Build Coastguard Worker        )
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        # If we plan to accumulate events we should post process the function events
371*da0073e9SAndroid Build Coastguard Worker        # right away to retain the state across mulitple start/stop calls
372*da0073e9SAndroid Build Coastguard Worker        if self.acc_events:
373*da0073e9SAndroid Build Coastguard Worker            self._ensure_function_events()
374*da0073e9SAndroid Build Coastguard Worker        return False
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker    def __repr__(self):
377*da0073e9SAndroid Build Coastguard Worker        if self._needs_processing:
378*da0073e9SAndroid Build Coastguard Worker            self._ensure_function_events()
379*da0073e9SAndroid Build Coastguard Worker        if self._function_events is None:
380*da0073e9SAndroid Build Coastguard Worker            return "<unfinished torch.autograd.profile>"
381*da0073e9SAndroid Build Coastguard Worker        return repr(self._function_events)
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
384*da0073e9SAndroid Build Coastguard Worker        if self._needs_processing:
385*da0073e9SAndroid Build Coastguard Worker            self._ensure_function_events()
386*da0073e9SAndroid Build Coastguard Worker        if self._function_events is None:
387*da0073e9SAndroid Build Coastguard Worker            return "<unfinished torch.autograd.profile>"
388*da0073e9SAndroid Build Coastguard Worker        return str(self._function_events)
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker    def _ensure_function_events(self):
391*da0073e9SAndroid Build Coastguard Worker        """Process function events lazily if required"""
392*da0073e9SAndroid Build Coastguard Worker        if self._function_events is not None:
393*da0073e9SAndroid Build Coastguard Worker            return
394*da0073e9SAndroid Build Coastguard Worker        self._needs_processing = False
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        t0 = perf_counter_ns()
397*da0073e9SAndroid Build Coastguard Worker        parsed_results = []
398*da0073e9SAndroid Build Coastguard Worker        if self.kineto_results:
399*da0073e9SAndroid Build Coastguard Worker            parsed_results = self._parse_kineto_results(self.kineto_results)
400*da0073e9SAndroid Build Coastguard Worker        t1 = perf_counter_ns()
401*da0073e9SAndroid Build Coastguard Worker        self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000)
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker        self._function_events = EventList(
404*da0073e9SAndroid Build Coastguard Worker            parsed_results,
405*da0073e9SAndroid Build Coastguard Worker            use_device=self.use_device,
406*da0073e9SAndroid Build Coastguard Worker            profile_memory=self.profile_memory,
407*da0073e9SAndroid Build Coastguard Worker            with_flops=self.with_flops,
408*da0073e9SAndroid Build Coastguard Worker        )
409*da0073e9SAndroid Build Coastguard Worker        t0 = perf_counter_ns()
410*da0073e9SAndroid Build Coastguard Worker        self._function_events._build_tree()
411*da0073e9SAndroid Build Coastguard Worker        t1 = perf_counter_ns()
412*da0073e9SAndroid Build Coastguard Worker        self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000)
413*da0073e9SAndroid Build Coastguard Worker        self._stats.number_of_events = len(self._function_events)
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker        if self._old_function_events and self.acc_events:
416*da0073e9SAndroid Build Coastguard Worker            for evt in self._old_function_events:
417*da0073e9SAndroid Build Coastguard Worker                self._function_events.append(evt)
418*da0073e9SAndroid Build Coastguard Worker            self._old_function_events = None
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        if self._function_events is None:
421*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Profiler didn't finish running")
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker    @property
424*da0073e9SAndroid Build Coastguard Worker    def function_events(self):
425*da0073e9SAndroid Build Coastguard Worker        if self._function_events is None or self._needs_processing:
426*da0073e9SAndroid Build Coastguard Worker            self._ensure_function_events()
427*da0073e9SAndroid Build Coastguard Worker        return self._function_events
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    def table(
430*da0073e9SAndroid Build Coastguard Worker        self,
431*da0073e9SAndroid Build Coastguard Worker        sort_by=None,
432*da0073e9SAndroid Build Coastguard Worker        row_limit=100,
433*da0073e9SAndroid Build Coastguard Worker        max_src_column_width=75,
434*da0073e9SAndroid Build Coastguard Worker        max_name_column_width=55,
435*da0073e9SAndroid Build Coastguard Worker        max_shapes_column_width=80,
436*da0073e9SAndroid Build Coastguard Worker        header=None,
437*da0073e9SAndroid Build Coastguard Worker        top_level_events_only=False,
438*da0073e9SAndroid Build Coastguard Worker    ):
439*da0073e9SAndroid Build Coastguard Worker        self._ensure_function_events()
440*da0073e9SAndroid Build Coastguard Worker        assert self._function_events is not None
441*da0073e9SAndroid Build Coastguard Worker        return self._function_events.table(
442*da0073e9SAndroid Build Coastguard Worker            sort_by=sort_by,
443*da0073e9SAndroid Build Coastguard Worker            row_limit=row_limit,
444*da0073e9SAndroid Build Coastguard Worker            max_src_column_width=max_src_column_width,
445*da0073e9SAndroid Build Coastguard Worker            max_name_column_width=max_name_column_width,
446*da0073e9SAndroid Build Coastguard Worker            max_shapes_column_width=max_shapes_column_width,
447*da0073e9SAndroid Build Coastguard Worker            header=header,
448*da0073e9SAndroid Build Coastguard Worker            top_level_events_only=top_level_events_only,
449*da0073e9SAndroid Build Coastguard Worker        )
450*da0073e9SAndroid Build Coastguard Worker
    # Reuse the docstring of EventList.table for this forwarding wrapper.
    table.__doc__ = EventList.table.__doc__
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker    def export_chrome_trace(self, path):
454*da0073e9SAndroid Build Coastguard Worker        """
455*da0073e9SAndroid Build Coastguard Worker        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
456*da0073e9SAndroid Build Coastguard Worker        last cycle in schedule is exported.
457*da0073e9SAndroid Build Coastguard Worker        """
458*da0073e9SAndroid Build Coastguard Worker        if kineto_available():
459*da0073e9SAndroid Build Coastguard Worker            self.kineto_results.save(path)  # type: ignore[union-attr]
460*da0073e9SAndroid Build Coastguard Worker        else:
461*da0073e9SAndroid Build Coastguard Worker            self._ensure_function_events()
462*da0073e9SAndroid Build Coastguard Worker            return self._function_events.export_chrome_trace(path)  # type: ignore[union-attr]
463*da0073e9SAndroid Build Coastguard Worker
    # Replace the docstring written on the method above with the one from
    # EventList.export_chrome_trace.
    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
467*da0073e9SAndroid Build Coastguard Worker        self._ensure_function_events()
468*da0073e9SAndroid Build Coastguard Worker        assert self._function_events is not None, "Expected profiling results"
469*da0073e9SAndroid Build Coastguard Worker        assert self.with_stack, "export_stacks() requires with_stack=True"
470*da0073e9SAndroid Build Coastguard Worker        return self._function_events.export_stacks(path, metric)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker    def toggle_collection_dynamic(
473*da0073e9SAndroid Build Coastguard Worker        self, enabled: bool, activities: Iterable[ProfilerActivity]
474*da0073e9SAndroid Build Coastguard Worker    ):
475*da0073e9SAndroid Build Coastguard Worker        """
476*da0073e9SAndroid Build Coastguard Worker        Toggles the collection of activities for the current profiler instance.
477*da0073e9SAndroid Build Coastguard Worker        """
478*da0073e9SAndroid Build Coastguard Worker        return _toggle_collection_dynamic(enabled, set(activities))
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
481*da0073e9SAndroid Build Coastguard Worker        self._ensure_function_events()
482*da0073e9SAndroid Build Coastguard Worker        assert self._function_events is not None, "Expected profiling results"
483*da0073e9SAndroid Build Coastguard Worker        return self._function_events.key_averages(
484*da0073e9SAndroid Build Coastguard Worker            group_by_input_shape, group_by_stack_n
485*da0073e9SAndroid Build Coastguard Worker        )
486*da0073e9SAndroid Build Coastguard Worker
    # Reuse the docstring of EventList.key_averages for this forwarding wrapper.
    key_averages.__doc__ = EventList.key_averages.__doc__
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker    def total_average(self):
490*da0073e9SAndroid Build Coastguard Worker        self._ensure_function_events()
491*da0073e9SAndroid Build Coastguard Worker        assert self._function_events is not None, "Expected profiling results"
492*da0073e9SAndroid Build Coastguard Worker        return self._function_events.total_average()
493*da0073e9SAndroid Build Coastguard Worker
    # Reuse the docstring of EventList.total_average for this forwarding wrapper.
    total_average.__doc__ = EventList.total_average.__doc__
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker    @property
497*da0073e9SAndroid Build Coastguard Worker    def self_cpu_time_total(self):
498*da0073e9SAndroid Build Coastguard Worker        """Returns total time spent on CPU.
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker        The total time is a sum of all self times across all the events.
501*da0073e9SAndroid Build Coastguard Worker        """
502*da0073e9SAndroid Build Coastguard Worker        self._ensure_function_events()
503*da0073e9SAndroid Build Coastguard Worker        assert self._function_events is not None
504*da0073e9SAndroid Build Coastguard Worker        return self._function_events.self_cpu_time_total
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker    def _parse_kineto_results(self, result: _ProfilerResult):
507*da0073e9SAndroid Build Coastguard Worker        # result.events() has most of the events - PyTorch op-level and device-level events
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        trace_start_ns = result.trace_start_ns()
510*da0073e9SAndroid Build Coastguard Worker        mem_records = [
511*da0073e9SAndroid Build Coastguard Worker            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
512*da0073e9SAndroid Build Coastguard Worker        ]
513*da0073e9SAndroid Build Coastguard Worker        oom_records = [
514*da0073e9SAndroid Build Coastguard Worker            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
515*da0073e9SAndroid Build Coastguard Worker        ]
516*da0073e9SAndroid Build Coastguard Worker        mem_records_acc = MemRecordsAcc(mem_records)
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        def _cpu_memory_usage(mem_record):
519*da0073e9SAndroid Build Coastguard Worker            return (
520*da0073e9SAndroid Build Coastguard Worker                mem_record.nbytes()
521*da0073e9SAndroid Build Coastguard Worker                if mem_record.device_type()
522*da0073e9SAndroid Build Coastguard Worker                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
523*da0073e9SAndroid Build Coastguard Worker                else 0
524*da0073e9SAndroid Build Coastguard Worker            )
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker        def _device_memory_usage(mem_record):
527*da0073e9SAndroid Build Coastguard Worker            return (
528*da0073e9SAndroid Build Coastguard Worker                mem_record.nbytes()
529*da0073e9SAndroid Build Coastguard Worker                if mem_record.device_type()
530*da0073e9SAndroid Build Coastguard Worker                in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
531*da0073e9SAndroid Build Coastguard Worker                else 0
532*da0073e9SAndroid Build Coastguard Worker            )
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker        # Create and return FunctionEvent list, which contains all function events
535*da0073e9SAndroid Build Coastguard Worker        # Here 2 function events are created:
536*da0073e9SAndroid Build Coastguard Worker        # all_function_events contains all events associated with each kineto event from result
537*da0073e9SAndroid Build Coastguard Worker        all_function_events = []
538*da0073e9SAndroid Build Coastguard Worker        # frontend_function_events contains the events in aten or torch frontend level,
539*da0073e9SAndroid Build Coastguard Worker        # whose correlation id is 0
540*da0073e9SAndroid Build Coastguard Worker        frontend_function_events = []
541*da0073e9SAndroid Build Coastguard Worker        device_corr_map: Dict[int, List[FunctionEvent]] = {}
542*da0073e9SAndroid Build Coastguard Worker        max_evt_id = 0
543*da0073e9SAndroid Build Coastguard Worker        for kineto_event in result.events():
544*da0073e9SAndroid Build Coastguard Worker            if _filter_name(kineto_event.name()):
545*da0073e9SAndroid Build Coastguard Worker                continue
546*da0073e9SAndroid Build Coastguard Worker            rel_start_ns = kineto_event.start_ns() - trace_start_ns
547*da0073e9SAndroid Build Coastguard Worker            rel_end_ns = kineto_event.end_ns() - trace_start_ns
548*da0073e9SAndroid Build Coastguard Worker            abs_end_ns = kineto_event.end_ns()
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker            cpu_memory_usage = 0
551*da0073e9SAndroid Build Coastguard Worker            device_memory_usage = 0
552*da0073e9SAndroid Build Coastguard Worker            if kineto_event.device_type() == DeviceType.CPU:
553*da0073e9SAndroid Build Coastguard Worker                # find the corresponding memory allocation events
554*da0073e9SAndroid Build Coastguard Worker                for mem_record in mem_records_acc.in_interval(
555*da0073e9SAndroid Build Coastguard Worker                    kineto_event.start_ns() / 1000, abs_end_ns / 1000
556*da0073e9SAndroid Build Coastguard Worker                ):
557*da0073e9SAndroid Build Coastguard Worker                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
558*da0073e9SAndroid Build Coastguard Worker                    device_memory_usage += _device_memory_usage(mem_record[0])
559*da0073e9SAndroid Build Coastguard Worker                    mem_record[1] = True
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Worker            is_async = kineto_event.is_async() or (
562*da0073e9SAndroid Build Coastguard Worker                kineto_event.start_thread_id() != kineto_event.end_thread_id()
563*da0073e9SAndroid Build Coastguard Worker            )
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker            fe = FunctionEvent(
566*da0073e9SAndroid Build Coastguard Worker                id=kineto_event.correlation_id(),
567*da0073e9SAndroid Build Coastguard Worker                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
568*da0073e9SAndroid Build Coastguard Worker                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
569*da0073e9SAndroid Build Coastguard Worker                thread=kineto_event.start_thread_id(),
570*da0073e9SAndroid Build Coastguard Worker                start_us=rel_start_ns / 1000,
571*da0073e9SAndroid Build Coastguard Worker                end_us=rel_end_ns / 1000,
572*da0073e9SAndroid Build Coastguard Worker                fwd_thread=kineto_event.fwd_thread_id(),
573*da0073e9SAndroid Build Coastguard Worker                input_shapes=kineto_event.shapes(),
574*da0073e9SAndroid Build Coastguard Worker                concrete_inputs=kineto_event.concrete_inputs(),
575*da0073e9SAndroid Build Coastguard Worker                kwinputs=kineto_event.kwinputs(),
576*da0073e9SAndroid Build Coastguard Worker                stack=[
577*da0073e9SAndroid Build Coastguard Worker                    entry
578*da0073e9SAndroid Build Coastguard Worker                    for entry in kineto_event.stack()
579*da0073e9SAndroid Build Coastguard Worker                    if _filter_stack_entry(entry)
580*da0073e9SAndroid Build Coastguard Worker                ],
581*da0073e9SAndroid Build Coastguard Worker                scope=kineto_event.scope(),
582*da0073e9SAndroid Build Coastguard Worker                use_device=self.use_device,
583*da0073e9SAndroid Build Coastguard Worker                cpu_memory_usage=cpu_memory_usage,
584*da0073e9SAndroid Build Coastguard Worker                device_memory_usage=device_memory_usage,
585*da0073e9SAndroid Build Coastguard Worker                is_async=is_async,
586*da0073e9SAndroid Build Coastguard Worker                sequence_nr=kineto_event.sequence_nr(),
587*da0073e9SAndroid Build Coastguard Worker                device_type=kineto_event.device_type(),
588*da0073e9SAndroid Build Coastguard Worker                device_index=kineto_event.device_index(),
589*da0073e9SAndroid Build Coastguard Worker                device_resource_id=kineto_event.device_resource_id(),
590*da0073e9SAndroid Build Coastguard Worker                flops=kineto_event.flops(),
591*da0073e9SAndroid Build Coastguard Worker                is_user_annotation=kineto_event.is_user_annotation(),
592*da0073e9SAndroid Build Coastguard Worker            )
593*da0073e9SAndroid Build Coastguard Worker            max_evt_id = max(max_evt_id, fe.id)
594*da0073e9SAndroid Build Coastguard Worker            if fe.device_type == DeviceType.CPU and not fe.is_async:
595*da0073e9SAndroid Build Coastguard Worker                if self.use_device == "privateuseone":
596*da0073e9SAndroid Build Coastguard Worker                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
597*da0073e9SAndroid Build Coastguard Worker                    if privateuse1_time > 0:
598*da0073e9SAndroid Build Coastguard Worker                        fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
599*da0073e9SAndroid Build Coastguard Worker                        fe.is_legacy = True
600*da0073e9SAndroid Build Coastguard Worker                elif self.use_device == "cuda":
601*da0073e9SAndroid Build Coastguard Worker                    # Check if we have CUDA time as a fallback
602*da0073e9SAndroid Build Coastguard Worker                    cuda_time = kineto_event.cuda_elapsed_us()
603*da0073e9SAndroid Build Coastguard Worker                    if cuda_time > 0:
604*da0073e9SAndroid Build Coastguard Worker                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
605*da0073e9SAndroid Build Coastguard Worker                        fe.is_legacy = True
606*da0073e9SAndroid Build Coastguard Worker            all_function_events.append(fe)
607*da0073e9SAndroid Build Coastguard Worker            corr_id = kineto_event.linked_correlation_id()
608*da0073e9SAndroid Build Coastguard Worker            if corr_id > 0:
609*da0073e9SAndroid Build Coastguard Worker                if corr_id not in device_corr_map:
610*da0073e9SAndroid Build Coastguard Worker                    device_corr_map[corr_id] = []
611*da0073e9SAndroid Build Coastguard Worker                device_corr_map[corr_id].append(fe)
612*da0073e9SAndroid Build Coastguard Worker            elif corr_id == 0:
613*da0073e9SAndroid Build Coastguard Worker                frontend_function_events.append(fe)
614*da0073e9SAndroid Build Coastguard Worker            else:
615*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
616*da0073e9SAndroid Build Coastguard Worker                    f"Got negative correlation id {corr_id} in profiler post processing"
617*da0073e9SAndroid Build Coastguard Worker                )
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        # associate device kernels and device runtime (CPU) with CPU events
620*da0073e9SAndroid Build Coastguard Worker        for fe in frontend_function_events:
621*da0073e9SAndroid Build Coastguard Worker            if (
622*da0073e9SAndroid Build Coastguard Worker                fe.device_type == DeviceType.CPU
623*da0073e9SAndroid Build Coastguard Worker                and not fe.is_async
624*da0073e9SAndroid Build Coastguard Worker                and fe.id in device_corr_map
625*da0073e9SAndroid Build Coastguard Worker            ):
626*da0073e9SAndroid Build Coastguard Worker                for f_evt in device_corr_map[fe.id]:
627*da0073e9SAndroid Build Coastguard Worker                    if (
628*da0073e9SAndroid Build Coastguard Worker                        f_evt.device_type == DeviceType.CUDA
629*da0073e9SAndroid Build Coastguard Worker                        or f_evt.device_type == DeviceType.PrivateUse1
630*da0073e9SAndroid Build Coastguard Worker                    ):
631*da0073e9SAndroid Build Coastguard Worker                        fe.append_kernel(
632*da0073e9SAndroid Build Coastguard Worker                            f_evt.name,
633*da0073e9SAndroid Build Coastguard Worker                            f_evt.device_index,
634*da0073e9SAndroid Build Coastguard Worker                            f_evt.time_range.end - f_evt.time_range.start,
635*da0073e9SAndroid Build Coastguard Worker                        )
636*da0073e9SAndroid Build Coastguard Worker                    elif f_evt.device_type == DeviceType.CPU:
637*da0073e9SAndroid Build Coastguard Worker                        # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated
638*da0073e9SAndroid Build Coastguard Worker                        # with the 'thread' of the corresponding linked PyTorch event to properly track
639*da0073e9SAndroid Build Coastguard Worker                        # parents and children
640*da0073e9SAndroid Build Coastguard Worker                        f_evt.thread = fe.thread
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker        def createFunctionEventForMemoryEvents(evt):
643*da0073e9SAndroid Build Coastguard Worker            rel_start_ns = evt.start_ns() - trace_start_ns
644*da0073e9SAndroid Build Coastguard Worker            fe = FunctionEvent(
645*da0073e9SAndroid Build Coastguard Worker                id=max_evt_id,
646*da0073e9SAndroid Build Coastguard Worker                name=evt.name(),
647*da0073e9SAndroid Build Coastguard Worker                trace_name=None,  # not outputting in the trace
648*da0073e9SAndroid Build Coastguard Worker                thread=evt.start_thread_id(),
649*da0073e9SAndroid Build Coastguard Worker                start_us=rel_start_ns / 1000,
650*da0073e9SAndroid Build Coastguard Worker                end_us=rel_start_ns / 1000,  # no duration
651*da0073e9SAndroid Build Coastguard Worker                fwd_thread=evt.start_thread_id(),
652*da0073e9SAndroid Build Coastguard Worker                input_shapes=[],
653*da0073e9SAndroid Build Coastguard Worker                stack=[],
654*da0073e9SAndroid Build Coastguard Worker                scope=0,  # RecordScope::FUNCTION
655*da0073e9SAndroid Build Coastguard Worker                use_device=self.use_device,
656*da0073e9SAndroid Build Coastguard Worker                cpu_memory_usage=_cpu_memory_usage(evt),
657*da0073e9SAndroid Build Coastguard Worker                device_memory_usage=_device_memory_usage(evt),
658*da0073e9SAndroid Build Coastguard Worker                is_async=False,
659*da0073e9SAndroid Build Coastguard Worker                sequence_nr=-1,
660*da0073e9SAndroid Build Coastguard Worker                device_type=DeviceType.CPU,
661*da0073e9SAndroid Build Coastguard Worker                device_index=0,
662*da0073e9SAndroid Build Coastguard Worker            )
663*da0073e9SAndroid Build Coastguard Worker            return fe
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker        # output top-level memory events
666*da0073e9SAndroid Build Coastguard Worker        for mem_record in mem_records:
667*da0073e9SAndroid Build Coastguard Worker            if not mem_record[1]:
668*da0073e9SAndroid Build Coastguard Worker                max_evt_id += 1
669*da0073e9SAndroid Build Coastguard Worker                fe = createFunctionEventForMemoryEvents(mem_record[0])
670*da0073e9SAndroid Build Coastguard Worker                all_function_events.append(fe)
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        for oom_record in oom_records:
673*da0073e9SAndroid Build Coastguard Worker            max_evt_id += 1
674*da0073e9SAndroid Build Coastguard Worker            fe = createFunctionEventForMemoryEvents(oom_record)
675*da0073e9SAndroid Build Coastguard Worker            all_function_events.append(fe)
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        all_function_events.sort(
678*da0073e9SAndroid Build Coastguard Worker            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
679*da0073e9SAndroid Build Coastguard Worker        )
680*da0073e9SAndroid Build Coastguard Worker        return all_function_events
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Workerclass record_function(_ContextDecorator):
684*da0073e9SAndroid Build Coastguard Worker    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
685*da0073e9SAndroid Build Coastguard Worker    Label will only appear if CPU activity tracing is enabled.
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker    It is useful when tracing the code profile.
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker    Args:
690*da0073e9SAndroid Build Coastguard Worker        name (str): Label assigned to the block of code.
691*da0073e9SAndroid Build Coastguard Worker        node_id (int): ID of node, for distributed profiling. Unset in
692*da0073e9SAndroid Build Coastguard Worker        non-distributed cases.
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker    Example:
695*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
696*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn((1, 1), requires_grad=True)
697*da0073e9SAndroid Build Coastguard Worker        >>> with torch.autograd.profiler.profile() as prof:
698*da0073e9SAndroid Build Coastguard Worker        ...     y = x ** 2
699*da0073e9SAndroid Build Coastguard Worker        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
700*da0073e9SAndroid Build Coastguard Worker        ...         z = y ** 3
701*da0073e9SAndroid Build Coastguard Worker        ...     y.backward()
702*da0073e9SAndroid Build Coastguard Worker        ...
703*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT
704*da0073e9SAndroid Build Coastguard Worker        >>> # NOTE: some columns were removed for brevity
705*da0073e9SAndroid Build Coastguard Worker        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
706*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
707*da0073e9SAndroid Build Coastguard Worker        Name                                 Self CPU total %  CPU time avg     Number of Calls
708*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
709*da0073e9SAndroid Build Coastguard Worker        pow                                  60.77%           47.470us         3
710*da0073e9SAndroid Build Coastguard Worker        mul                                  21.73%           25.465us         2
711*da0073e9SAndroid Build Coastguard Worker        PowBackward0                         12.03%           121.891us        1
712*da0073e9SAndroid Build Coastguard Worker        torch::autograd::AccumulateGrad      2.70%            6.324us          1
713*da0073e9SAndroid Build Coastguard Worker        label-z                              2.13%            12.421us         1
714*da0073e9SAndroid Build Coastguard Worker        torch::autograd::GraphRoot           0.64%            1.503us          1
715*da0073e9SAndroid Build Coastguard Worker        -----------------------------------  ---------------  ---------------  ---------------
716*da0073e9SAndroid Build Coastguard Worker        Self CPU time total: 234.344us
717*da0073e9SAndroid Build Coastguard Worker        CUDA time total: 0.000us
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker    """
720*da0073e9SAndroid Build Coastguard Worker
    def __init__(self, name: str, args: Optional[str] = None):
        """Store the label and optional argument string for the recorded scope.

        Args:
            name: Label assigned to the block of code.
            args: Optional string that is passed, together with ``name``, to the
                record-function enter op in ``__enter__``.
                NOTE(review): its exact meaning is defined by that op - confirm.
        """
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        # Starts out as None; ``__enter__`` stores the handle returned by the enter op.
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )
731*da0073e9SAndroid Build Coastguard Worker
    def __enter__(self):
        """Open the record-function scope and return this object.

        The handle returned by ``_record_function_enter_new`` is kept in
        ``self.record`` so that ``__exit__`` can close the scope.
        """
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self
737*da0073e9SAndroid Build Coastguard Worker
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        """Close the record-function scope opened by ``__enter__``.

        Does nothing when ``run_callbacks_on_exit`` is False. The three
        exception arguments are not inspected, and the method returns ``None``.
        """
        if not self.run_callbacks_on_exit:
            return

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            # Eager mode: call the ``_RecordFunction`` overload of the exit op
            # with torch-function subclass handling disabled.
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            # Scripting: call the exit op directly.
            torch.ops.profiler._record_function_exit(record)
753*da0073e9SAndroid Build Coastguard Worker
    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end to end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut (torch._C.Future): future for which to schedule the end callbacks.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have run.

        Raises:
            RuntimeError: if ``run_callbacks_on_exit`` is already ``False``, which is
                the state this method leaves the instance in after its first call.

        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            # Eager path: call the ``_RecordFunction`` overload directly while
            # __torch_function__ subclass handling is disabled (see TODO above).
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            # Scripted path: call the op without selecting an overload.
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker
class emit_itt:
    """Context manager under which every autograd operation emits an ITT range.

    Intended for programs run under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API lets an application generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager annotates the Intel(R) VTune Profiling trace, so that labeled
    ranges show up in the Intel(R) VTune Profiler GUI.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of itt range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)

    """

    def __init__(self, enabled=True, record_shapes=False):
        self.record_shapes = record_shapes
        self.enabled = enabled
        # Set on the first enabled __enter__ and never cleared afterwards, so an
        # enabled instance can be entered at most once.
        self.entered = False

    def __enter__(self):
        """Turn on ITT profiling and return ``self``; return ``None`` when disabled."""
        if self.enabled:
            if self.entered:
                raise RuntimeError("ITT annotation context manager is not reentrant")
            self.entered = True
            _run_on_profiler_start()
            # Only ``record_shapes`` is configurable; the remaining boolean
            # options of ProfilerConfig are passed as False.
            itt_config = ProfilerConfig(
                ProfilerState.ITT,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            )
            _enable_profiler(itt_config, set())
            return self
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Turn the profiler off again; exceptions from the body are never suppressed."""
        if self.enabled:
            _disable_profiler()
            _run_on_profiler_stop()
            return False
        return None
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker
class emit_nvtx:
    """Context manager under which every autograd operation emits an NVTX range.

    Intended for programs run under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately nvprof cannot be forced to flush the data it has collected to disk,
    so for CUDA profiling this context manager is used to annotate the nvprof trace,
    and the trace can only be inspected once the process has exited.
    After that, the timeline can be visualized with the NVIDIA Visual Profiler (nvvp), or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection,
    e.g. in a Python REPL.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of nvtx range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)

    **Forward-backward correlation**

    In a profile produced with :class:`emit_nvtx` and viewed in the Nvidia Visual Profiler,
    it can be hard to tell which forward-pass op a given backward-pass op belongs to.
    To make that easier, :class:`emit_nvtx` appends sequence number information to the
    ranges it generates.

    In the forward pass every function range is decorated with ``seq=<N>``.  ``seq`` is a
    running counter that is incremented each time a new backward Function object is created
    and stashed for backward.  The ``seq=<N>`` annotation on a forward function range
    therefore says: if this forward function creates a backward Function object, that
    object will receive sequence number N.
    In the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``, where ``M`` is the sequence
    number the backward object was created with.  Matching ``stashed seq`` numbers in
    backward against ``seq`` numbers in forward shows which forward op created each
    backward Function.

    Functions executed during the backward pass are decorated with ``seq=<N>`` as well.
    During default backward (with ``create_graph=False``) this information is irrelevant,
    and ``N`` may in fact simply be 0 for all such functions.  Only the top-level ranges
    associated with backward Function objects' ``apply()`` methods are useful, as a way to
    correlate these Function objects with the earlier forward pass.

    **Double-backward**

    If instead a backward pass with ``create_graph=True`` is underway (that is, you are
    setting up for a double-backward), each function's execution during backward is given
    a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
    to be executed later during double-backward, exactly as the original functions did in
    the forward pass.  Backward relates to double-backward conceptually the same way
    forward relates to backward: the functions still emit current-sequence-number-tagged
    ranges, the Function objects they create still stash those sequence numbers, and during
    the eventual double-backward the Function objects' ``apply()`` ranges are still tagged
    with ``stashed seq`` numbers, which can be compared to ``seq`` numbers from the
    backward pass.

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object.  You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.record_shapes = record_shapes
        self.enabled = enabled
        # Set on the first enabled __enter__ and never cleared afterwards, so an
        # enabled instance can be entered at most once.
        self.entered = False

    def __enter__(self):
        """Turn on NVTX profiling and return ``self``; return ``None`` when disabled."""
        if self.enabled:
            if self.entered:
                raise RuntimeError("NVTX annotation context manager is not reentrant")
            self.entered = True
            # Synchronize CUDA before the profiler is switched on.
            torch.cuda.synchronize()
            _run_on_profiler_start()
            # Only ``record_shapes`` is configurable; the remaining boolean
            # options of ProfilerConfig are passed as False.
            nvtx_config = ProfilerConfig(
                ProfilerState.NVTX,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            )
            _enable_profiler(nvtx_config, set())
            return self
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Turn the profiler off again; exceptions from the body are never suppressed."""
        if self.enabled:
            # Synchronize CUDA before the profiler is switched off.
            torch.cuda.synchronize()
            _disable_profiler()
            _run_on_profiler_stop()
            return False
        return None
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker
def load_nvprof(path):
    """Parse the autograd annotations of an nvprof trace file into an ``EventList``.

    Args:
        path (str): path to nvprof trace

    Returns:
        An ``EventList`` wrapping the events produced by ``parse_nvprof_trace(path)``.
    """
    parsed_events = parse_nvprof_trace(path)
    return EventList(parsed_events)
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker
class EnforceUnique:
    """Remembers every key it is shown and rejects any key shown a second time."""

    def __init__(self):
        # Keys observed so far; each key is the tuple of arguments given to ``see``.
        self.seen = set()

    def see(self, *key):
        r"""
        Record ``key`` (the positional arguments, as a tuple) as seen.

        Raises ``RuntimeError`` if the same tuple was already recorded.
        """
        if key not in self.seen:
            self.seen.add(key)
            return
        raise RuntimeError("duplicate key: " + str(key))
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker
def parse_nvprof_trace(path):
    """Parse an nvprof SQLite trace into a list of ``FunctionEvent`` objects.

    One ``FunctionEvent`` is created per marker range (a start marker joined to
    the end marker with the same id).  Every kernel whose runtime launch call
    lies strictly inside a marker range is then attached to that range's event.

    Args:
        path (str): path to the nvprof trace (an SQLite database file).

    Returns:
        list: the ``FunctionEvent`` objects, sorted by ``time_range.start``.

    Raises:
        RuntimeError: if a marker id, or a (marker id, runtime id) pair, shows
            up more than once in the query results.
        AssertionError: if a correlated runtime record has a ``cbid`` other
            than 211.
    """
    import sqlite3

    # Pairs each start marker (name != 0) with its end marker (name = 0).
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    # Joins marker ranges to the runtime calls they enclose and, through the
    # correlation id, to the kernels those calls launched.
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
            ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
            ON kernel.correlationId = runtime.correlationId
    """

    conn = sqlite3.connect(path)
    # Close the connection on every exit path (including the duplicate-key and
    # assertion failures below) so the trace file handle is not leaked.
    try:
        conn.row_factory = sqlite3.Row

        # Parse strings table
        strings = {}
        for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
            strings[r["id"]] = torch._C._demangle(r["value"])

        # First, find all functions and create FunctionEvents for them
        functions = []
        functions_map = {}
        unique = EnforceUnique()
        for row in conn.execute(marker_query):
            unique.see(row["marker_id"])
            evt = FunctionEvent(
                id=row["marker_id"],
                node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
                # that pytorch doesn't crash when creating a FunctionEvent() object
                name=strings[row["name"]],
                start_us=row["start_time"],
                end_us=row["end_time"],
                thread=0,
            )  # TODO: find in sqlite database
            functions.append(evt)
            functions_map[evt.id] = evt

        # Now, correlate all kernels with FunctionEvents
        unique = EnforceUnique()
        for row in conn.execute(kernel_query):
            unique.see(row["marker_id"], row["runtime_id"])
            # 211 is cudaKernelLaunch for cuda >= 9.2
            assert row["cbid"] == 211
            evt = functions_map[row["marker_id"]]
            # The kernel duration is end - start; the middle argument is passed
            # as 0 (presumably the device index -- TODO confirm against
            # FunctionEvent.append_kernel).
            evt.append_kernel(
                row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
            )
    finally:
        conn.close()

    functions.sort(key=lambda evt: evt.time_range.start)
    return functions
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker
1082*da0073e9SAndroid Build Coastguard Workerclass KinetoStepTracker:
1083*da0073e9SAndroid Build Coastguard Worker    """Provides an abstraction for incrementing the step count globally.
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker    Previously, we only had one place to mark that a step() has occurred
1086*da0073e9SAndroid Build Coastguard Worker    in the program via pytorch profiler step(). We will now add step hooks
1087*da0073e9SAndroid Build Coastguard Worker    in the Optimizer class https://github.com/pytorch/pytorch/issues/88446
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker    - This could mean programs that already call profiler.step() every
1090*da0073e9SAndroid Build Coastguard Worker      iteration can end up double incrementing step count.
1091*da0073e9SAndroid Build Coastguard Worker    - If a model uses multiple optimizers we can also have double or more
1092*da0073e9SAndroid Build Coastguard Worker      counting of the step.
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker    We fix this by adding a layer of abstraction before calling step()
1095*da0073e9SAndroid Build Coastguard Worker    to the kineto library. The idea is to maintain steps per requester in a dict:
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker    .. code-block::
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker        {
1100*da0073e9SAndroid Build Coastguard Worker           "ProfilerStep": 100,  # triggered by profiler step() call
1101*da0073e9SAndroid Build Coastguard Worker           "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
1102*da0073e9SAndroid Build Coastguard Worker           "Optimizer2Step": 100,
1103*da0073e9SAndroid Build Coastguard Worker        }
1104*da0073e9SAndroid Build Coastguard Worker
1105*da0073e9SAndroid Build Coastguard Worker    To figure out the global step count just take the max of dict values (100).
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker    If one of the count increments the max will go up.
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker    .. code-block::
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker        {
1112*da0073e9SAndroid Build Coastguard Worker           "ProfilerStep": 100,
1113*da0073e9SAndroid Build Coastguard Worker           "Optimizer1Step": 101,   # Optimizer1 got incremented first say
1114*da0073e9SAndroid Build Coastguard Worker           "Optimizer2Step": 100,
1115*da0073e9SAndroid Build Coastguard Worker        }
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker    Then global step count is 101
1118*da0073e9SAndroid Build Coastguard Worker    We only call the kineto step() function when global count increments.
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker    NOTE: Please do not use the KinetoStepTracker in modules beside the Optimizer
1121*da0073e9SAndroid Build Coastguard Worker    for now. The result could be incorrect increments of the step count.
1122*da0073e9SAndroid Build Coastguard Worker    """
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    _current_step = 0
1125*da0073e9SAndroid Build Coastguard Worker    _step_dict: Dict[str, int] = defaultdict(int)
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker    @classmethod
1128*da0073e9SAndroid Build Coastguard Worker    def init_step_count(cls, requester: str):
1129*da0073e9SAndroid Build Coastguard Worker        r"""
1130*da0073e9SAndroid Build Coastguard Worker        Initialize for a given requester.
1131*da0073e9SAndroid Build Coastguard Worker        """
1132*da0073e9SAndroid Build Coastguard Worker        cls._step_dict[requester] = cls._current_step
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker    @classmethod
1135*da0073e9SAndroid Build Coastguard Worker    def erase_step_count(cls, requester: str) -> bool:
1136*da0073e9SAndroid Build Coastguard Worker        r"""
1137*da0073e9SAndroid Build Coastguard Worker        Remove a given requester.
1138*da0073e9SAndroid Build Coastguard Worker        """
1139*da0073e9SAndroid Build Coastguard Worker        return cls._step_dict.pop(requester, None) is not None
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker    @classmethod
1142*da0073e9SAndroid Build Coastguard Worker    def increment_step(cls, requester: str) -> int:
1143*da0073e9SAndroid Build Coastguard Worker        """Increments the step count for the requester.
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker        Additionally if the max over all step counts has incremented then
1146*da0073e9SAndroid Build Coastguard Worker        trigger the _kineto_step() returns global step count
1147*da0073e9SAndroid Build Coastguard Worker        """
1148*da0073e9SAndroid Build Coastguard Worker        if requester not in cls._step_dict:
1149*da0073e9SAndroid Build Coastguard Worker            cls.init_step_count(requester)
1150*da0073e9SAndroid Build Coastguard Worker        cls._step_dict[requester] += 1
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        new_step = max(cls._step_dict.values())
1153*da0073e9SAndroid Build Coastguard Worker        if new_step > cls._current_step:
1154*da0073e9SAndroid Build Coastguard Worker            delta = new_step - cls._current_step
1155*da0073e9SAndroid Build Coastguard Worker            if delta > 1:
1156*da0073e9SAndroid Build Coastguard Worker                warn(
1157*da0073e9SAndroid Build Coastguard Worker                    "Profiler step count has increased more than 1 - "
1158*da0073e9SAndroid Build Coastguard Worker                    f"current_step = {cls._current_step} step dict =  {cls._step_dict}"
1159*da0073e9SAndroid Build Coastguard Worker                )
1160*da0073e9SAndroid Build Coastguard Worker            for _ in range(0, delta):
1161*da0073e9SAndroid Build Coastguard Worker                _kineto_step()
1162*da0073e9SAndroid Build Coastguard Worker            cls._current_step = new_step
1163*da0073e9SAndroid Build Coastguard Worker        return cls._current_step
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker    @classmethod
1166*da0073e9SAndroid Build Coastguard Worker    def current_step(cls) -> int:
1167*da0073e9SAndroid Build Coastguard Worker        r"""
1168*da0073e9SAndroid Build Coastguard Worker        Get the latest step for any requester
1169*da0073e9SAndroid Build Coastguard Worker        """
1170*da0073e9SAndroid Build Coastguard Worker        return cls._current_step
1171