# mypy: allow-untyped-defs
from collections import defaultdict
from dataclasses import dataclass
from time import perf_counter_ns
from typing import Any, Dict, Iterable, List, Optional
from warnings import warn

import torch
import torch.cuda
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
    _disable_profiler,
    _enable_profiler,
    _kineto_step,
    _prepare_profiler,
    _ProfilerResult,
    _supported_activities,
    _toggle_collection_dynamic,
    DeviceType,
    kineto_available,
    ProfilerActivity,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
    MemRecordsAcc,
    OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future


__all__ = [
    "profile",
    "record_function",
    "emit_itt",
    "emit_nvtx",
    "load_nvprof",
    "EnforceUnique",
    "parse_nvprof_trace",
    "KinetoStepTracker",
    "EventList",
    "FunctionEvent",
    "MemRecordsAcc",
]

try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped


# global python state - whether profiler is currently enabled
# useful for fast python checks to reduce latency
_is_profiler_enabled: bool = False


def _set_is_profiler_enabled(enable: bool):
    global _is_profiler_enabled
    _is_profiler_enabled = enable


def _run_on_profiler_start():
    _set_is_profiler_enabled(True)


def _run_on_profiler_stop():
    _set_is_profiler_enabled(False)


@dataclass
class _ProfilerStats:
    "Profiler timing and stats used by developers to catch issues/regressions"
    profiling_window_duration_sec: float = 0
    number_of_events: int = 0
    profiler_prepare_call_duration_us: int = 0
    profiler_enable_call_duration_us: int = 0
    profiler_disable_call_duration_us: int = 0
    parse_kineto_call_duration_us: int = 0
    function_events_build_tree_call_duration_us: int = 0


class profile:
    """Context manager that manages autograd profiler state and holds a summary of results.

    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: the profiler is thread local and is automatically propagated into async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.

        use_cuda (bool, optional): Enables timing of CUDA events as well
            using the cudaEvent API. (will be deprecated)

        use_device (str, optional): Enables timing of device events.
            Adds approximately 4us of overhead to each tensor operation when using CUDA.
            The valid device options are 'cuda', 'xpu', 'mtia' and 'privateuseone'.

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottommost events (in the case
            of nested function calls). But for higher level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        with_flops (bool, optional): If with_flops is set, the profiler will estimate
            the FLOPs (floating point operations) value using the operator's input shape.
            This allows one to estimate the hardware performance. Currently,
            this option only works for the matrix multiplication and 2D convolution operators.

        profile_memory (bool, optional): track tensor memory allocation/deallocation.

        with_stack (bool, optional): record source information (file and line number) for the ops.

        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.

        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.

        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.

        experimental_config (_ExperimentalConfig): A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.

        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.


    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead

    .. warning::
        This context manager should not be called recursively, i.e. no nested
        instances are allowed

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_device = None`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        >>>     for _ in range(100):  # any normal python code, really!
        >>>         y = x ** 2
        >>>         y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total   CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        mul                                  32.048ms         32.048ms         200
        pow                                  27.041ms         27.041ms         200
        PowBackward0                         9.727ms          55.483ms         100
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
        torch::autograd::GraphRoot           691.816us        691.816us        100
        -----------------------------------  ---------------  ---------------  ---------------

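    The same profiler can also collect device-side timing and memory events.
    A brief sketch (not an upstream example; assumes a CUDA-capable build and
    an available device)::

        >>> # xdoctest: +SKIP
        >>> with torch.autograd.profiler.profile(use_device="cuda", profile_memory=True) as prof:
        ...     y = (x.cuda() ** 2).sum()
        ...     y.backward()
        >>> print(prof.key_averages().table())
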
    """

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,  # Deprecated
        use_device=None,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
        use_kineto=False,
        use_cpu=True,
        experimental_config=None,
        acc_events=False,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        if self.use_cuda:
            warn(
                "The attribute `use_cuda` will be deprecated soon, "
                "please use ``use_device = 'cuda'`` instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.use_device: Optional[str] = "cuda"
        else:
            self.use_device = use_device
        # TODO Consider changing _function_events into a data structure with a size cap
        self._function_events: Optional[EventList] = None
        self._old_function_events: Optional[EventList] = None
        # Function event processing is done lazily
        self._needs_processing = False
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.use_cpu = use_cpu
        self.acc_events = acc_events
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.kineto_results: Optional[_ProfilerResult] = None
        self.profiling_start_time_ns = 0
        self.profiling_end_time_ns = 0
        self._stats = _ProfilerStats()

        if not self.use_cpu:
            assert (
                use_kineto
            ), "Device-only events supported only with Kineto (use_kineto=True)"

        if self.use_device is not None:
            VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"]
            if _get_privateuse1_backend_name() != "privateuseone":
                VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
            if self.use_device not in VALID_DEVICE_OPTIONS:
                warn(f"The {self.use_device} is not a valid device option.")
                self.use_device = None

            if self.use_device == "cuda" and not torch.cuda.is_available():
                warn("CUDA is not available, disabling CUDA profiling")
                self.use_cuda = False
                self.use_device = None

            if self.use_device == "xpu" and not torch.xpu.is_available():
                warn("XPU is not available, disabling XPU profiling")
                self.use_device = None

        self.kineto_activities = set()
        if self.use_cpu:
            self.kineto_activities.add(ProfilerActivity.CPU)

        self.profiler_kind = ProfilerState.KINETO
        if self.use_device == "cuda":
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.CUDA)
        elif self.use_device == "xpu":
            assert (
                use_kineto and ProfilerActivity.XPU in _supported_activities()
            ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
            self.kineto_activities.add(ProfilerActivity.XPU)
        elif self.use_device == "mtia":
            assert (
                use_kineto and ProfilerActivity.MTIA in _supported_activities()
            ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
            self.kineto_activities.add(ProfilerActivity.MTIA)
        elif self.use_device is not None and self.use_device != "privateuseone":
            if (
                not use_kineto
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
            ):
                assert (
                    self.use_cpu
                ), "Legacy custom backend profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)

        assert (
            len(self.kineto_activities) > 0
        ), "No activities specified for the profiler"

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config,
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self._prepare_trace()
        self._start_trace()
        return self

    def _prepare_trace(self):
        self.entered = True
        t0 = perf_counter_ns()
        _prepare_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000)

    def _start_trace(self):
        self.entered = True
        _run_on_profiler_start()
        t0 = perf_counter_ns()
        _enable_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000)
        self.profiling_start_time_ns = t1

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_device and hasattr(torch, self.use_device):
            device_module = getattr(torch, self.use_device)
            if hasattr(device_module, "synchronize"):
                device_module.synchronize()

        if self._function_events and self.acc_events:
            self._old_function_events = self._function_events
        self._function_events = None
        self._needs_processing = True

        t0 = perf_counter_ns()

        self.kineto_results = _disable_profiler()
        t1 = perf_counter_ns()
        self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000)
        self.profiling_end_time_ns = t0

        _run_on_profiler_stop()

        self._stats.profiling_window_duration_sec = (
            (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9
        )

        # If we plan to accumulate events we should post process the function events
        # right away to retain the state across multiple start/stop calls
        if self.acc_events:
            self._ensure_function_events()
        return False

    def __repr__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return repr(self._function_events)

    def __str__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return str(self._function_events)

    def _ensure_function_events(self):
        """Process function events lazily if required"""
        if self._function_events is not None:
            return
        self._needs_processing = False

        t0 = perf_counter_ns()
        parsed_results = []
        if self.kineto_results:
            parsed_results = self._parse_kineto_results(self.kineto_results)
        t1 = perf_counter_ns()
        self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000)

        self._function_events = EventList(
            parsed_results,
            use_device=self.use_device,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        t0 = perf_counter_ns()
        self._function_events._build_tree()
        t1 = perf_counter_ns()
        self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000)
        self._stats.number_of_events = len(self._function_events)

        if self._old_function_events and self.acc_events:
            for evt in self._old_function_events:
                self._function_events.append(evt)
            self._old_function_events = None

        if self._function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    @property
    def function_events(self):
        if self._function_events is None or self._needs_processing:
            self._ensure_function_events()
        return self._function_events

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._ensure_function_events()
        assert self._function_events is not None
        return self._function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        the last cycle in the schedule is exported.
        """
        if kineto_available():
            self.kineto_results.save(path)  # type: ignore[union-attr]
        else:
            self._ensure_function_events()
            return self._function_events.export_chrome_trace(path)  # type: ignore[union-attr]

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self._function_events.export_stacks(path, metric)

    def toggle_collection_dynamic(
        self, enabled: bool, activities: Iterable[ProfilerActivity]
    ):
        """
        Toggles the collection of activities for the current profiler instance.
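
        A minimal usage sketch (``model`` and ``inputs`` are placeholders, not
        part of this module)::

            >>> # xdoctest: +SKIP
            >>> with torch.autograd.profiler.profile() as prof:
            ...     model(inputs)
            ...     # pause CPU activity collection for a region we want to skip
            ...     prof.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CPU])
            ...     model(inputs)
            ...     # resume collection
            ...     prof.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CPU])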
        """
        return _toggle_collection_dynamic(enabled, set(activities))

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        return self._function_events.key_averages(
            group_by_input_shape, group_by_stack_n
        )

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        return self._function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Returns total time spent on CPU.

        The total time is a sum of all self times across all the events.
        """
        self._ensure_function_events()
        assert self._function_events is not None
        return self._function_events.self_cpu_time_total

    def _parse_kineto_results(self, result: _ProfilerResult):
        # result.events() has most of the events - PyTorch op-level and device-level events

        trace_start_ns = result.trace_start_ns()
        mem_records = [
            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
        ]
        oom_records = [
            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
        ]
        mem_records_acc = MemRecordsAcc(mem_records)

        def _cpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
                else 0
            )

        def _device_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
                else 0
            )

        # Create and return the FunctionEvent list, which contains all function events
        # Two lists of function events are created here:
        # all_function_events contains all events associated with each kineto event from result
        all_function_events = []
        # frontend_function_events contains the events at the aten or torch frontend level,
        # whose correlation id is 0
        frontend_function_events = []
        device_corr_map: Dict[int, List[FunctionEvent]] = {}
        max_evt_id = 0
        for kineto_event in result.events():
            if _filter_name(kineto_event.name()):
                continue
            rel_start_ns = kineto_event.start_ns() - trace_start_ns
            rel_end_ns = kineto_event.end_ns() - trace_start_ns
            abs_end_ns = kineto_event.end_ns()

            cpu_memory_usage = 0
            device_memory_usage = 0
            if kineto_event.device_type() == DeviceType.CPU:
                # find the corresponding memory allocation events
                for mem_record in mem_records_acc.in_interval(
                    kineto_event.start_ns() / 1000, abs_end_ns / 1000
                ):
                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
                    device_memory_usage += _device_memory_usage(mem_record[0])
                    mem_record[1] = True

            is_async = kineto_event.is_async() or (
                kineto_event.start_thread_id() != kineto_event.end_thread_id()
            )

            fe = FunctionEvent(
                id=kineto_event.correlation_id(),
                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
                thread=kineto_event.start_thread_id(),
                start_us=rel_start_ns / 1000,
                end_us=rel_end_ns / 1000,
                fwd_thread=kineto_event.fwd_thread_id(),
                input_shapes=kineto_event.shapes(),
                concrete_inputs=kineto_event.concrete_inputs(),
                kwinputs=kineto_event.kwinputs(),
                stack=[
                    entry
                    for entry in kineto_event.stack()
                    if _filter_stack_entry(entry)
                ],
                scope=kineto_event.scope(),
                use_device=self.use_device,
                cpu_memory_usage=cpu_memory_usage,
                device_memory_usage=device_memory_usage,
                is_async=is_async,
                sequence_nr=kineto_event.sequence_nr(),
                device_type=kineto_event.device_type(),
                device_index=kineto_event.device_index(),
                device_resource_id=kineto_event.device_resource_id(),
                flops=kineto_event.flops(),
                is_user_annotation=kineto_event.is_user_annotation(),
            )
            max_evt_id = max(max_evt_id, fe.id)
            if fe.device_type == DeviceType.CPU and not fe.is_async:
                if self.use_device == "privateuseone":
                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
                    if privateuse1_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
                        fe.is_legacy = True
                elif self.use_device == "cuda":
                    # Check if we have CUDA time as a fallback
                    cuda_time = kineto_event.cuda_elapsed_us()
                    if cuda_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
                        fe.is_legacy = True
            all_function_events.append(fe)
            corr_id = kineto_event.linked_correlation_id()
            if corr_id > 0:
                if corr_id not in device_corr_map:
                    device_corr_map[corr_id] = []
                device_corr_map[corr_id].append(fe)
            elif corr_id == 0:
                frontend_function_events.append(fe)
            else:
                raise RuntimeError(
                    f"Got negative correlation id {corr_id} in profiler post processing"
                )

        # associate device kernels and device runtime (CPU) with CPU events
        for fe in frontend_function_events:
            if (
                fe.device_type == DeviceType.CPU
                and not fe.is_async
                and fe.id in device_corr_map
            ):
                for f_evt in device_corr_map[fe.id]:
                    if (
                        f_evt.device_type == DeviceType.CUDA
                        or f_evt.device_type == DeviceType.PrivateUse1
                    ):
                        fe.append_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.CPU:
                        # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated
                        # with the 'thread' of the corresponding linked PyTorch event to properly track
                        # parents and children
                        f_evt.thread = fe.thread

        def createFunctionEventForMemoryEvents(evt):
            rel_start_ns = evt.start_ns() - trace_start_ns
            fe = FunctionEvent(
                id=max_evt_id,
                name=evt.name(),
                trace_name=None,  # not outputting in the trace
                thread=evt.start_thread_id(),
                start_us=rel_start_ns / 1000,
                end_us=rel_start_ns / 1000,  # no duration
                fwd_thread=evt.start_thread_id(),
                input_shapes=[],
                stack=[],
                scope=0,  # RecordScope::FUNCTION
                use_device=self.use_device,
                cpu_memory_usage=_cpu_memory_usage(evt),
                device_memory_usage=_device_memory_usage(evt),
                is_async=False,
                sequence_nr=-1,
                device_type=DeviceType.CPU,
                device_index=0,
            )
            return fe

        # output top-level memory events
        for mem_record in mem_records:
            if not mem_record[1]:
                max_evt_id += 1
                fe = createFunctionEventForMemoryEvents(mem_record[0])
                all_function_events.append(fe)

        for oom_record in oom_records:
            max_evt_id += 1
            fe = createFunctionEventForMemoryEvents(oom_record)
            all_function_events.append(fe)

        all_function_events.sort(
            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
        )
        return all_function_events


class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.

    The label will only appear if CPU activity tracing is enabled.
    It is useful for identifying specific regions of code in the resulting trace.

    Args:
        name (str): Label assigned to the block of code.
        args (str, optional): An optional string argument that is recorded
            alongside the label.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
        ...         z = y ** 3
        ...     y.backward()
        ...
        >>> # xdoctest: +IGNORE_WANT
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total %  CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        pow                                  60.77%           47.470us         3
        mul                                  21.73%           25.465us         2
        PowBackward0                         12.03%           121.891us        1
        torch::autograd::AccumulateGrad      2.70%            6.324us          1
        label-z                              2.13%            12.421us         1
        torch::autograd::GraphRoot           0.64%            1.503us          1
        -----------------------------------  ---------------  ---------------  ---------------
        Self CPU time total: 234.344us
        CUDA time total: 0.000us

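    It can also be used as a decorator, which labels every call to the wrapped
    function (a brief sketch, not an upstream example)::

        >>> @torch.autograd.profiler.record_function("my_function")
        ... def my_function(t):
        ...     return t * 2
        >>> # calls to my_function() now show up under the "my_function" label
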
    """

    def __init__(self, name: str, args: Optional[str] = None):
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )

    def __enter__(self):
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if not self.run_callbacks_on_exit:
            return

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            torch.ops.profiler._record_function_exit(record)

    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end to end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut (torch._C.Future): future for which to schedule the callback.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have run.

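        A rough sketch using a TorchScript future (``slow_fn`` and ``x`` are
        placeholders, not part of this module)::

            >>> # xdoctest: +SKIP
            >>> with record_function("async_region") as rf:
            ...     fut = torch.jit.fork(slow_fn, x)
            ...     rf._call_end_callbacks_on_future(fut)
            >>> result = torch.jit.wait(fut)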
        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future


class emit_itt:
    """Context manager that makes every autograd operation emit an ITT range.

    It is useful when running the program under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager annotates the Intel(R) VTune Profiler trace. With the help of this context manager,
    you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the ITT range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of ITT range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)

    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ITT annotation context manager is not reentrant")
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.ITT,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        _disable_profiler()
        _run_on_profiler_stop()
        return False


class emit_nvtx:
    """Context manager that makes every autograd operation emit an NVTX range.

    It is useful when running the program under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately, there's no way to force nvprof to flush the data it collected
    to disk, so for CUDA profiling one has to use this context manager to annotate
    nvprof traces and wait for the process to exit before inspecting them.
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
    e.g. in Python REPL.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the NVTX range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of NVTX range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)

    **Forward-backward correlation**

    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
    generates.

    During the forward pass, each function range is decorated with ``seq=<N>``.  ``seq`` is a running
    counter, incremented each time a new backward Function object is created and stashed for backward.
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
    if a backward Function object is created by this forward function,
    the backward object will receive sequence number N.
    During the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``.  ``M`` is the sequence number that
    the backward object was created with.  By comparing ``stashed seq`` numbers in backward with ``seq``
    numbers in forward, you can track down which forward op created each backward Function.

    Any functions executed during the backward pass are also decorated with ``seq=<N>``.  During
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
    ``N`` may simply be 0 for all such functions.  Only the top-level ranges associated with
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
    objects with the earlier forward pass.

    **Double-backward**

    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
    if you are setting up for a double-backward), each function's execution during backward
    is given a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
    to be executed later during double-backward, just as the original functions in the forward pass did.
    The relationship between backward and double-backward is conceptually the same as the relationship
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
    the Function objects they create still stash those sequence numbers, and during the eventual
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
    numbers, which can be compared to ``seq`` numbers from the backward pass.

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object.  You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.NVTX,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        torch.cuda.synchronize()
        _disable_profiler()
        _run_on_profiler_stop()
        return False


def load_nvprof(path):
    """Open an nvprof trace file and parse autograd annotations.

    Args:
        path (str): path to nvprof trace
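
    A short illustrative usage, assuming an nvprof trace exists at ``trace.prof``::

        >>> # xdoctest: +SKIP
        >>> events = torch.autograd.profiler.load_nvprof("trace.prof")
        >>> print(events.table())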
    """
    return EventList(parse_nvprof_trace(path))


class EnforceUnique:
    """Raises an error if a key is seen more than once."""

    def __init__(self):
        self.seen = set()

    def see(self, *key):
        r"""
        Observe a key and raise an error if it is seen multiple times.
        """
        if key in self.seen:
            raise RuntimeError("duplicate key: " + str(key))
        self.seen.add(key)


def parse_nvprof_trace(path):
    import sqlite3

    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row

    # Parse strings table
    strings = {}
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
        strings[r["id"]] = torch._C._demangle(r["value"])

    # First, find all functions and create FunctionEvents for them
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    functions = []
    functions_map = {}
    unique = EnforceUnique()
    for row in conn.execute(marker_query):
        unique.see(row["marker_id"])
        evt = FunctionEvent(
            id=row["marker_id"],
            node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
            # that PyTorch doesn't crash when creating a FunctionEvent() object
            name=strings[row["name"]],
            start_us=row["start_time"],
            end_us=row["end_time"],
            thread=0,
        )  # TODO: find in sqlite database
        functions.append(evt)
        functions_map[evt.id] = evt

    # Now, correlate all kernels with FunctionEvents
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
            ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
            ON kernel.correlationId = runtime.correlationId
    """
    unique = EnforceUnique()
    for row in conn.execute(kernel_query):
        unique.see(row["marker_id"], row["runtime_id"])
        # 211 is cudaKernelLaunch for cuda >= 9.2
        assert row["cbid"] == 211
        evt = functions_map[row["marker_id"]]
        evt.append_kernel(
            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
        )

    functions.sort(key=lambda evt: evt.time_range.start)
    return functions


class KinetoStepTracker:
    """Provides an abstraction for incrementing the step count globally.

    Previously, we only had one place to mark that a step() has occurred
    in the program, via the PyTorch profiler's step(). We will now add step hooks
    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446):

    - This could mean programs that already call profiler.step() every
      iteration can end up double incrementing the step count.
    - If a model uses multiple optimizers we can also have double or more
      counting of the step.

    We fix this by adding a layer of abstraction before calling step()
    to the kineto library. The idea is to maintain steps per requester in a dict:

    .. code-block::

        {
           "ProfilerStep": 100,  # triggered by profiler step() call
           "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
           "Optimizer2Step": 100,
        }

    To figure out the global step count, just take the max of the dict values (100).

    If one of the counts increments, the max will go up.

    .. code-block::

        {
           "ProfilerStep": 100,
           "Optimizer1Step": 101,   # Optimizer1 got incremented first say
           "Optimizer2Step": 100,
        }

    Then the global step count is 101.
    We only call the kineto step() function when the global count increments.
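
    A minimal usage sketch (the requester name here is arbitrary)::

        >>> # xdoctest: +SKIP
        >>> from torch.autograd.profiler import KinetoStepTracker
        >>> KinetoStepTracker.init_step_count("Optimizer1Step")
        >>> KinetoStepTracker.increment_step("Optimizer1Step")
        >>> KinetoStepTracker.current_step()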

    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
    for now. The result could be incorrect increments of the step count.
    """

    _current_step = 0
    _step_dict: Dict[str, int] = defaultdict(int)

    @classmethod
    def init_step_count(cls, requester: str):
        r"""
        Initialize for a given requester.
        """
        cls._step_dict[requester] = cls._current_step

    @classmethod
    def erase_step_count(cls, requester: str) -> bool:
        r"""
        Remove a given requester.
        """
        return cls._step_dict.pop(requester, None) is not None

    @classmethod
    def increment_step(cls, requester: str) -> int:
        """Increments the step count for the requester.

        Additionally, if the max over all step counts has incremented, trigger
        the _kineto_step() call. Returns the global step count.
        """
        if requester not in cls._step_dict:
            cls.init_step_count(requester)
        cls._step_dict[requester] += 1

        new_step = max(cls._step_dict.values())
        if new_step > cls._current_step:
            delta = new_step - cls._current_step
            if delta > 1:
                warn(
                    "Profiler step count has increased more than 1 - "
                    f"current_step = {cls._current_step} step dict =  {cls._step_dict}"
                )
            for _ in range(0, delta):
                _kineto_step()
            cls._current_step = new_step
        return cls._current_step

    @classmethod
    def current_step(cls) -> int:
        r"""
        Get the latest step for any requester
        """
        return cls._current_step