# mypy: allow-untyped-defs
import gzip
import json
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing_extensions import Self
from warnings import warn

import torch
import torch.autograd.profiler as prof
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _ExperimentalConfig,
    _remove_execution_trace_observer,
)
from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline


__all__ = [
    "supported_activities",
    "ProfilerAction",
    "schedule",
    "tensorboard_trace_handler",
    "profile",
    "ExecutionTraceObserver",
]
PROFILER_STEP_NAME = "ProfilerStep"


def supported_activities():
    """
    Returns a set of supported profiler tracing activities.

    Note: the profiler uses the CUPTI library to trace on-device CUDA kernels.
    If CUDA is enabled but CUPTI is not available, passing
    ``ProfilerActivity.CUDA`` to the profiler results in using the legacy CUDA
    profiling code (the same as in the legacy ``torch.autograd.profiler``).
    This, in turn, results in CUDA time being included in the profiler table
    output, but not in the JSON trace.
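
    Example (a minimal sketch; the exact set depends on the build and the
    available hardware):

    .. code-block:: python

        acts = torch.profiler.supported_activities()
        # On a CUDA build with CUPTI available, this typically contains
        # ProfilerActivity.CPU and ProfilerActivity.CUDA.
        assert torch.profiler.ProfilerActivity.CPU in acts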
49    """
50    return torch.autograd._supported_activities()
51
52
class _ITraceObserver(ABC):
    """Abstract interface for a trace observer.
    It declares three methods: start, stop, and cleanup."""

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass


class _KinetoProfile:
    """Low-level profiler that wraps the autograd profiler.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
            for more details).
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use a formula to estimate the FLOPs of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record the module hierarchy (including function names)
            corresponding to the op's callstack. For example, if module A's forward
            calls module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support currently exists only for TorchScript models
            and not for eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph-based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.


    .. note::
        This API is experimental and subject to change in the future.

        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors,
        which may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.
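
    Example (a minimal sketch of direct use; this low-level class is mostly used
    internally, and ``torch.profiler.profile`` is the recommended user-facing API):

    .. code-block:: python

        p = _KinetoProfile(activities=[torch.profiler.ProfilerActivity.CPU])
        p.start()
        run_workload()  # hypothetical user function
        p.stop()
        print(p.key_averages().table(sort_by="self_cpu_time_total"))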
108    """
109
110    def __init__(
111        self,
112        *,
113        activities: Optional[Iterable[ProfilerActivity]] = None,
114        record_shapes: bool = False,
115        profile_memory: bool = False,
116        with_stack: bool = False,
117        with_flops: bool = False,
118        with_modules: bool = False,
119        experimental_config: Optional[_ExperimentalConfig] = None,
120        execution_trace_observer: Optional[_ITraceObserver] = None,
121        acc_events: bool = False,
122    ):
123        self.activities = set(activities) if activities else supported_activities()
124        self.record_shapes = record_shapes
125        self.with_flops = with_flops
126        self.profile_memory = profile_memory
127        self.with_stack = with_stack
128        self.with_modules = with_modules
129        self.experimental_config = experimental_config
130        self.execution_trace_observer = execution_trace_observer
131        self.acc_events = acc_events
132        self.profiler: Optional[prof.profile] = None
133        self.mem_tl: Optional[MemoryProfileTimeline] = None
134        self.use_device = None
135        if ProfilerActivity.CUDA in self.activities:
136            self.use_device = "cuda"
137        elif ProfilerActivity.XPU in self.activities:
138            self.use_device = "xpu"
139        elif ProfilerActivity.MTIA in self.activities:
140            self.use_device = "mtia"
141        elif ProfilerActivity.PrivateUse1 in self.activities:
142            self.use_device = _get_privateuse1_backend_name()
143
        # user-defined metadata to be added to the trace
        self.preset_metadata: Dict[str, str] = {}

    def start(self):
        self.prepare_trace()
        self.start_trace()

    def stop(self):
        self.stop_trace()

    def prepare_trace(self):
        if (self.profiler is None) or (not self.acc_events):
            self.profiler = prof.profile(
                use_cpu=(ProfilerActivity.CPU in self.activities),
                use_device=self.use_device,
                record_shapes=self.record_shapes,
                with_flops=self.with_flops,
                profile_memory=self.profile_memory,
                with_stack=self.with_stack,
                with_modules=self.with_modules,
                use_kineto=True,
                experimental_config=self.experimental_config,
                acc_events=self.acc_events,
            )
        self.profiler._prepare_trace()

    def start_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.start()
        assert self.profiler is not None
        self.profiler._start_trace()

        if self.profile_memory:
            self.add_metadata_json("profile_memory", "1")
        if self.with_stack:
            self.add_metadata_json("with_stack", "1")
        if self.record_shapes:
            self.add_metadata_json("record_shapes", "1")
        if self.with_modules:
            self.add_metadata_json("with_modules", "1")
        if self.with_flops:
            self.add_metadata_json("with_flops", "1")

        if kineto_available():
            dist_info = self._get_distributed_info()
            if dist_info:
                self.add_metadata_json("distributedInfo", json.dumps(dist_info))

            if hasattr(torch, "_inductor"):
                import torch._inductor.config as inductor_config

                if inductor_config.triton.cudagraphs:
                    os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
                    self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
                    # FIXME: CUDA Graph does not work well with CUPTI teardown.
                    #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
                    #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
                    # Workaround: turn off CUPTI teardown when using CUDA Graphs.
                    os.environ["TEARDOWN_CUPTI"] = "0"

            # Insert the preset user metadata to the trace
            for k, v in self.preset_metadata.items():
                self.add_metadata_json(k, v)

    def stop_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.stop()
        assert self.profiler is not None
        self.profiler.__exit__(None, None, None)

    def export_chrome_trace(self, path: str):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        the last cycle in the schedule is exported.
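
        Example (a minimal sketch; a ``.gz`` suffix makes the method gzip the
        output, as shown below):

        .. code-block:: python

            with torch.profiler.profile() as p:
                run_workload()  # hypothetical user function
            p.export_chrome_trace("trace.json.gz")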
218        """
219        assert self.profiler
220        if path.endswith(".gz"):
221            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
222            fp.close()
223            retvalue = self.profiler.export_chrome_trace(fp.name)
224            with open(fp.name) as fin:
225                with gzip.open(path, "wt") as fout:
226                    fout.writelines(fin)
227            os.remove(fp.name)
228            return retvalue
229        else:
230            return self.profiler.export_chrome_trace(path)
231
    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Save stack traces to a file.

        Args:
            path (str): save stacks file to this location.
            metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total".
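
        Example (a minimal sketch; ``with_stack=True`` is required for stack
        collection, and the resulting file can be fed to external flame graph
        tooling):

        .. code-block:: python

            with torch.profiler.profile(with_stack=True) as p:
                run_workload()  # hypothetical user function
            p.export_stacks("profiler_stacks.txt", "self_cpu_time_total")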
238        """
239        assert self.profiler
240        return self.profiler.export_stacks(path, metric)
241
    def toggle_collection_dynamic(
        self, enable: bool, activities: Iterable[ProfilerActivity]
    ):
        """Toggle collection of activities on/off at any point of collection.
        Currently supports toggling Torch ops (CPU) and CUDA activity supported
        in Kineto.

        Args:
            activities (iterable): list of activity groups to use in profiling, supported values:
                ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``
        Examples:

        .. code-block:: python

            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]
            ) as p:
                code_to_profile_0()
                # turn off collection of all CUDA activity
                p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_1()
                # turn on collection of all CUDA activity
                p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_2()
            print(p.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
        """
        if not self.profiler:
            return
        self.profiler.toggle_collection_dynamic(enable, activities)

    def key_averages(
        self, group_by_input_shape: bool = False, group_by_stack_n: int = 0
    ):
        """Averages events, grouping them by operator name and (optionally) input shapes
        and stack.

        .. note::
            To use the shape/stack functionality, make sure to set record_shapes/with_stack
            when creating the profiler context manager.
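
        Example (a minimal sketch, assuming ``record_shapes=True`` was passed so
        that grouping by input shape is meaningful):

        .. code-block:: python

            with torch.profiler.profile(record_shapes=True) as p:
                run_workload()  # hypothetical user function
            print(p.key_averages(group_by_input_shape=True).table(row_limit=10))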
284        """
285        assert self.profiler
286        return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)
287
288    def events(self):
289        """
290        Returns the list of unaggregated profiler events,
291        to be used in the trace callback or after the profiling is finished
292        """
293        assert self.profiler
294        return self.profiler.function_events
295
    def add_metadata(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a string value
        into the trace file.
        """
        wrapped_value = '"' + value.replace('"', '\\"') + '"'
        torch.autograd._add_metadata_json(key, wrapped_value)

    def add_metadata_json(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a valid JSON value
        into the trace file.
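
        Example (a minimal sketch, where ``prof`` is an instance of this class
        and the value is already valid JSON):

        .. code-block:: python

            prof.add_metadata_json("run_config", '{"batch_size": 32, "amp": true}')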
308        """
309        torch.autograd._add_metadata_json(key, value)
310
    def preset_metadata_json(self, key: str, value: str):
        """
        Presets user-defined metadata while the profiler is not yet started,
        to be added into the trace file later.
        The metadata takes the form of a string key and a valid JSON value.
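
        Example (a minimal sketch, where ``prof`` is a not-yet-started instance
        of this class; the metadata is written once tracing starts):

        .. code-block:: python

            prof.preset_metadata_json("experiment", '"baseline"')
            prof.start()  # "experiment" is added to the trace at this point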
316        """
317        self.preset_metadata[key] = value
318
319    def _get_distributed_info(self):
320        import torch.distributed as dist
321
322        if not dist.is_available() or not dist.is_initialized():
323            return None
324
325        backend = dist.get_backend()
326        dist_info = {
327            "backend": backend,
328            "rank": dist.get_rank(),
329            "world_size": dist.get_world_size(),
330            "pg_count": dist.get_pg_count(),
331            "pg_config": dist.distributed_c10d._get_all_pg_configs(),
332        }
333        if backend == "nccl":
334            nccl_version = torch.cuda.nccl.version()
335            dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
336        return dist_info
337
338    def _memory_profile(self) -> MemoryProfile:
339        required = ("record_shapes", "profile_memory", "with_stack")
340        missing = [f"{i}=True" for i in required if not getattr(self, i)]
341        if missing:
342            raise ValueError(f"{', '.join(missing)} required for memory profiling.")
343
344        assert self.profiler is not None and self.profiler.kineto_results is not None
345        return MemoryProfile(self.profiler.kineto_results)
346
    def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
        """Export memory event information from the profiler-collected
        tree for a given device, and export a timeline plot. There are three
        exportable formats for ``export_memory_timeline``, each selected by the
        suffix of ``path``.

        - For an HTML-compatible plot, use the suffix ``.html``; a memory timeline
          plot will be embedded as a PNG file in the HTML file.

        - For plot points consisting of ``[times, [sizes by category]]``, where
          ``times`` are timestamps and ``sizes`` are memory usage per category,
          the memory timeline plot will be saved as JSON (``.json``) or gzipped
          JSON (``.json.gz``), depending on the suffix.

        - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
          event consists of ``(timestamp, action, numbytes, category)``, where
          ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
          and ``category`` is one of the enums from
          ``torch.profiler._memory_profiler.Category``.

        Output: Memory timeline written as gzipped JSON, JSON, or HTML.
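
        Example (a minimal sketch; ``record_shapes``, ``profile_memory``, and
        ``with_stack`` must all be enabled for memory profiling):

        .. code-block:: python

            with torch.profiler.profile(
                record_shapes=True, profile_memory=True, with_stack=True
            ) as p:
                run_workload()  # hypothetical user function
            p.export_memory_timeline("memory_timeline.html", device="cuda:0")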
368        """
369        # Default to device 0, if unset. Fallback on cpu.
370        if device is None and self.use_device and self.use_device != "cuda":
371            device = self.use_device + ":0"
372
373        if device is None:
374            device = "cuda:0" if torch.cuda.is_available() else "cpu"
375
376        # Construct the memory timeline plot data
377        self.mem_tl = MemoryProfileTimeline(self._memory_profile())
378
379        # Depending on the file suffix, save the data as json.gz or json.
380        # For html, we can embed the image into an HTML file.
381        if path.endswith(".html"):
382            self.mem_tl.export_memory_timeline_html(path, device)
383        elif path.endswith(".gz"):
384            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
385            fp.close()
386            if path.endswith("raw.json.gz"):
387                self.mem_tl.export_memory_timeline_raw(fp.name, device)
388            else:
389                self.mem_tl.export_memory_timeline(fp.name, device)
390            with open(fp.name) as fin:
391                with gzip.open(path, "wt") as fout:
392                    fout.writelines(fin)
393            os.remove(fp.name)
394        else:
395            self.mem_tl.export_memory_timeline(path, device)
396
397
398class ProfilerAction(Enum):
399    """
400    Profiler actions that can be taken at the specified intervals
401    """
402
403    NONE = 0
404    WARMUP = 1
405    RECORD = 2
406    RECORD_AND_SAVE = 3
407
408
def schedule(
    *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
) -> Callable:
    """
    Returns a callable that can be used as the profiler ``schedule`` argument. The profiler will
    skip the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the
    next ``warmup`` steps, then do the active recording for the next ``active`` steps, and then
    repeat the cycle starting with ``wait`` steps. The optional number of cycles is specified
    with the ``repeat`` parameter; a value of zero means that the cycles will continue until
    profiling is finished.
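
    Example (a minimal sketch of the per-step actions produced by one
    configuration):

    .. code-block:: python

        # With wait=1, warmup=1, active=2, repeat=1 the returned callable
        # yields: step 0 -> NONE, step 1 -> WARMUP, step 2 -> RECORD,
        # step 3 -> RECORD_AND_SAVE, and NONE from step 4 on (the repeat
        # limit has been reached).
        sched = torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1)
        assert sched(3) == torch.profiler.ProfilerAction.RECORD_AND_SAVE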
418    """
419
420    def schedule_fn(step: int) -> ProfilerAction:
421        assert step >= 0
422        if step < skip_first:
423            return ProfilerAction.NONE
424        else:
425            step -= skip_first
426        num_steps = wait + warmup + active
427        if repeat > 0 and step / num_steps >= repeat:
428            return ProfilerAction.NONE
429        mod_step = step % num_steps
430        if mod_step < wait:
431            return ProfilerAction.NONE
432        elif mod_step < wait + warmup:
433            return ProfilerAction.WARMUP
434        else:
435            return (
436                ProfilerAction.RECORD
437                if mod_step < num_steps - 1
438                else ProfilerAction.RECORD_AND_SAVE
439            )
440
441    assert (
442        wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
443    ), "Invalid profiler schedule arguments"
444    if warmup == 0:
445        warn("Profiler won't be using warmup, this can skew profiler results")
446    return schedule_fn
447
448
449def _default_schedule_fn(_: int) -> ProfilerAction:
450    """
451    Default profiler behavior - immediately starts recording the events,
452    keeps doing it on every profiler step.
453    """
454    return ProfilerAction.RECORD
455
456
def tensorboard_trace_handler(
    dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False
):
    """
    Outputs tracing files to the directory ``dir_name``, which can then be
    passed directly to TensorBoard as its logdir.
    ``worker_name`` should be unique for each worker in a distributed scenario;
    it defaults to '[hostname]_[pid]'.
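
    Example (a minimal sketch; traces land in ``./log`` and can be viewed with
    ``tensorboard --logdir ./log`` once the TensorBoard plugin is installed):

    .. code-block:: python

        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
            on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
        ) as p:
            for step in range(10):
                run_workload()  # hypothetical user function
                p.step()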
465    """
466    import os
467    import socket
468    import time
469
470    def handler_fn(prof) -> None:
471        nonlocal worker_name
472        if not os.path.isdir(dir_name):
473            try:
474                os.makedirs(dir_name, exist_ok=True)
475            except Exception as e:
476                raise RuntimeError("Can't create directory: " + dir_name) from e
477        if not worker_name:
478            worker_name = f"{socket.gethostname()}_{os.getpid()}"
479        # Use nanosecond here to avoid naming clash when exporting the trace
480        file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
481        if use_gzip:
482            file_name = file_name + ".gz"
483        prof.export_chrome_trace(os.path.join(dir_name, file_name))
484
485    return handler_fn
486
487
class profile(_KinetoProfile):
    """Profiler context manager.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        schedule (Callable): callable that takes step (int) as a single parameter and returns
            a ``ProfilerAction`` value that specifies the profiler action to perform at each step.
        on_trace_ready (Callable): callable that is called at each step when ``schedule``
            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation.
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use a formula to estimate the FLOPs (floating point operations) of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record the module hierarchy (including function names)
            corresponding to the op's callstack. For example, if module A's forward
            calls module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support currently exists only for TorchScript models
            and not for eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used for Kineto library features. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph-based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler. See the examples section below for a code sample.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.
        use_cuda (bool):
            .. deprecated:: 1.8.1
                use ``activities`` instead.

    .. note::
        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
        Non-default schedules are useful when profiling long training jobs,
        allowing the user to obtain multiple traces at different iterations
        of the training process.
        The default schedule simply records all the events continuously for the
        duration of the context manager.

    .. note::
        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:

        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``

        After profiling, result files can be found in the specified directory. Use the command:

        ``tensorboard --logdir dir_name``

        to see the results in TensorBoard.
        For more information, see
        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__

    .. note::
        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors,
        which may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.


    Examples:

    .. code-block:: python

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            code_to_profile()
        print(p.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))

    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

    .. code-block:: python

        # Non-default profiler schedule allows user to turn profiler on and off
        # on different iterations of the training loop;
        # trace_handler is called every time a new trace becomes available
        def trace_handler(prof):
            print(prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],

            # In this example with wait=1, warmup=1, active=2, repeat=1,
            # profiler will skip the first step/iteration,
            # start warming up on the second, record
            # the third and the fourth iterations,
            # after which the trace will become available
            # and on_trace_ready (when set) is called;
            # the cycle repeats starting with the next step

            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2,
                repeat=1),
            on_trace_ready=trace_handler
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            # used when outputting for tensorboard
            ) as p:
                for iter in range(N):
                    code_iteration_to_profile(iter)
                    # send a signal to the profiler that the next iteration has started
                    p.step()

    The following sample shows how to set up an Execution Trace Observer (``execution_trace_observer``):

    .. code-block:: python

        with torch.profiler.profile(
            ...
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback("./execution_trace.json")
            ),
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                p.step()

    You can also refer to test_execution_trace_with_kineto() in test/profiler/test_profiler.py.
    Note: one can also pass any object satisfying the _ITraceObserver interface.
    """

    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        schedule: Optional[Callable[[int], ProfilerAction]] = None,
        on_trace_ready: Optional[Callable[..., Any]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        acc_events: bool = False,
        # deprecated:
        use_cuda: Optional[bool] = None,
    ):
        activities_set = set(activities) if activities else supported_activities()
        if use_cuda is not None:
            warn(
                "`use_cuda` is deprecated, use `activities` argument instead",
                FutureWarning,
                stacklevel=2,
            )
            if use_cuda:
                activities_set.add(ProfilerActivity.CUDA)
            elif ProfilerActivity.CUDA in activities_set:
                activities_set.remove(ProfilerActivity.CUDA)
        assert len(activities_set) > 0, "No valid profiler activities found"

        super().__init__(
            # pass the resolved set so the deprecated `use_cuda` adjustment above takes effect
            activities=activities_set,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
            experimental_config=experimental_config,
            execution_trace_observer=execution_trace_observer,
            acc_events=acc_events,
        )

        if schedule:
            self.schedule = schedule
            # add step markers into the trace and table view
            self.record_steps = True
        else:
            self.schedule = _default_schedule_fn
            self.record_steps = False
        self.on_trace_ready = on_trace_ready
        self.step_num = 0
        self.current_action = self.schedule(self.step_num)
        self.step_rec_fn: Optional[prof.record_function] = None

        self.action_map: Dict[
            Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
        ] = {
            # key is (prev_action, current_action), value is action list corresponding to the state pair.
            (ProfilerAction.NONE, ProfilerAction.NONE): [],
            (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
            (ProfilerAction.NONE, ProfilerAction.RECORD): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
                self.start_trace,
                self.stop_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
            (ProfilerAction.RECORD, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: RECORD followed by NONE"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
                partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
            (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
                self.stop_trace,
                self._trace_ready,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            # used for exit action
            (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
            (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
            (ProfilerAction.RECORD_AND_SAVE, None): [
                self.stop_trace,
                self._trace_ready,
            ],
        }
        # Start tracking increments to the profiler step; this count is used
        # by Kineto.
        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
        if self.execution_trace_observer:
            self.execution_trace_observer.cleanup()

    def start(self):
        self._transit_action(ProfilerAction.NONE, self.current_action)
        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def stop(self):
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        self._transit_action(self.current_action, None)

    def step(self):
        """
        Signals the profiler that the next profiling step has started.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        prev_action = self.current_action
        self.step_num += 1
        self.current_action = self.schedule(self.step_num)

        self._transit_action(prev_action, self.current_action)
        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def _trace_ready(self):
        if self.on_trace_ready:
            self.on_trace_ready(self)

    def _transit_action(self, prev_action, current_action):
        action_list = self.action_map.get((prev_action, current_action))
        if action_list:
            for action in action_list:
                action()

    def _stats(self) -> Optional[prof._ProfilerStats]:
        if self.profiler is None:
            return None
        return self.profiler._stats


class ExecutionTraceObserver(_ITraceObserver):
    """Execution Trace Observer.

    Each process can have a single ExecutionTraceObserver instance. The observer
    can be added to record function callbacks by calling register_callback()
    explicitly. Without calling unregister_callback(), repeated calls to
    register_callback() will not add additional observers to record function
    callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
    methods control when the event data is recorded.

    Deleting the observer or calling unregister_callback() will remove it from
    the record function callbacks, finalize the output file, and stop
    incurring any overhead.
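
    Example (a minimal sketch of standalone use, outside of
    ``torch.profiler.profile``):

    .. code-block:: python

        et = ExecutionTraceObserver().register_callback("./execution_trace.json")
        et.start()
        run_workload()  # hypothetical user function
        et.stop()
        et.unregister_callback()  # finalizes the output file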
814    """
815
816    def __init__(self) -> None:
817        """
818        Initializes the default states.
819        """
820        self._registered = False
821        self._execution_trace_running = False
822
823    def __del__(self):
824        """
825        Calls unregister_callback() to make sure to finalize outputs.
826        """
827        self.unregister_callback()
828
829    def register_callback(self, output_file_path: str) -> Self:
830        """
831        Adds ET observer to record function callbacks. The data will be
832        written to output_file_path.
833        """
834        if not self._registered:
835            self._output_file_path = output_file_path
836            self._registered = _add_execution_trace_observer(output_file_path)
837        return self
838
839    def unregister_callback(self):
840        """
841        Removes ET observer from record function callbacks.
842        """
843
844        def _save_triton_kernels():
845            # Save the kernel paths for the generated kernels
846            from torch._inductor.codecache import PyCodeCache as PyCodeCache
847
848            kernel_files = [
849                v.__file__
850                for v in PyCodeCache.cache.values()
851                if getattr(v, "__file__", None) is not None
852            ]
853            work_dir, file_name = os.path.split(self._output_file_path)
854            resource_dir = os.path.join(
855                work_dir, os.path.splitext(file_name)[0] + "_resources"
856            )
857            if not os.path.exists(resource_dir):
858                os.mkdir(resource_dir)
859
860            for kernel_file in kernel_files:
861                if kernel_file is None:
862                    continue
863                path, name = os.path.split(kernel_file)
864                dst = os.path.join(resource_dir, name)
865                shutil.copyfile(kernel_file, dst)
866
867        if self._registered:
868            self.stop()
869            try:
870                _save_triton_kernels()
871            except Exception as e:
872                warn(f"Execution trace failed to save kernels: {e}")
873            _remove_execution_trace_observer()
874            self._registered = False
875
876    @property
877    def is_registered(self):
878        """
879        Returns True if the execution trace observer is registered, otherwise False.
880        """
881        return self._registered
882
883    def is_running(self):
884        """
885        Returns True if the observer is running, otherwise False.
886        """
887        return self._execution_trace_running
888
    def start(self):
        """
        Starts capturing.
        """
        if self._registered and not self._execution_trace_running:
            _enable_execution_trace_observer()
            self._execution_trace_running = True
            self._record_pg_config()

    def stop(self):
        """
        Stops capturing.
        """
        if self._execution_trace_running:
            _disable_execution_trace_observer()
            self._execution_trace_running = False

    def cleanup(self):
        """
        Calls unregister_callback() to make sure to finalize outputs.
        """
        self.unregister_callback()

    def get_output_file_path(self) -> str:
        """
        Returns the output file path.
        """
        if self.is_registered:
            return self._output_file_path
        else:
            raise RuntimeError(
                "A callback to the ET profiler needs to be registered "
                "first before getting the output file path"
            )

    def _record_pg_config(self) -> None:
        # Records the PG config info to the trace as node:
        #  ## process_group:init ##
        if (
            self.is_registered
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
            torch.autograd._record_function_with_args_enter(
                "## process_group:init ##", json.dumps(pg_config_info)
            )