xref: /aosp_15_r20/external/pytorch/torch/profiler/profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport gzip
3*da0073e9SAndroid Build Coastguard Workerimport json
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport shutil
6*da0073e9SAndroid Build Coastguard Workerimport tempfile
7*da0073e9SAndroid Build Coastguard Workerfrom abc import ABC, abstractmethod
8*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
9*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
10*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
11*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self
12*da0073e9SAndroid Build Coastguard Workerfrom warnings import warn
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torch
15*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.profiler as prof
16*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _get_privateuse1_backend_name
17*da0073e9SAndroid Build Coastguard Workerfrom torch._C._profiler import (
18*da0073e9SAndroid Build Coastguard Worker    _add_execution_trace_observer,
19*da0073e9SAndroid Build Coastguard Worker    _disable_execution_trace_observer,
20*da0073e9SAndroid Build Coastguard Worker    _enable_execution_trace_observer,
21*da0073e9SAndroid Build Coastguard Worker    _ExperimentalConfig,
22*da0073e9SAndroid Build Coastguard Worker    _remove_execution_trace_observer,
23*da0073e9SAndroid Build Coastguard Worker)
24*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import kineto_available, ProfilerActivity
25*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
# Public names of this module, i.e. what ``from torch.profiler.profiler import *`` exposes.
__all__ = [
    # helpers
    "supported_activities",
    "schedule",
    "tensorboard_trace_handler",
    # scheduling / profiling types
    "ProfilerAction",
    "profile",
    "ExecutionTraceObserver",
]

# Label string for profiler steps (its use sites are outside this chunk).
PROFILER_STEP_NAME: str = "ProfilerStep"
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
def supported_activities():
    """
    Return the set of profiler tracing activities supported by this build.

    This simply forwards to ``torch.autograd._supported_activities()``.

    Note: on-device CUDA kernels are traced through the CUPTI library. When
    CUDA is enabled but CUPTI is unavailable, requesting
    ``ProfilerActivity.CUDA`` falls back to the legacy CUDA profiling path
    (the same one used by the legacy ``torch.autograd.profiler``). As a result,
    CUDA time shows up in the profiler table output but not in the JSON trace.
    """
    available = torch.autograd._supported_activities()
    return available
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
class _ITraceObserver(ABC):
    """Abstract interface for a trace observer.

    Concrete observers must implement all three hooks: ``start``, ``stop``
    and ``cleanup``. The class cannot be instantiated until they are defined.
    """

    @abstractmethod
    def start(self):
        """Begin observing (``_KinetoProfile`` calls this before starting its trace)."""
        ...

    @abstractmethod
    def stop(self):
        """Stop observing (``_KinetoProfile`` calls this before exiting its profiler)."""
        ...

    @abstractmethod
    def cleanup(self):
        """Release observer state; no caller of this hook is visible in this chunk."""
        ...
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker
class _KinetoProfile:
    """Low-level profiler that wraps the autograd profiler (``torch.autograd.profiler.profile``).

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
            for more details).
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use formula to estimate the FLOPS of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included the observer start() and stop() will be called for the
            same time window as PyTorch profiler.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.

    .. note::
        This API is experimental and subject to change in the future.

        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.
    """

    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        acc_events: bool = False,
    ):
        # A missing *or empty* activities iterable falls back to every supported activity.
        self.activities = set(activities) if activities else supported_activities()
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.experimental_config = experimental_config
        self.execution_trace_observer = execution_trace_observer
        self.acc_events = acc_events
        self.profiler: Optional[prof.profile] = None
        self.mem_tl: Optional[MemoryProfileTimeline] = None
        # At most one accelerator is profiled; the first match in this priority order wins.
        self.use_device = None
        if ProfilerActivity.CUDA in self.activities:
            self.use_device = "cuda"
        elif ProfilerActivity.XPU in self.activities:
            self.use_device = "xpu"
        elif ProfilerActivity.MTIA in self.activities:
            self.use_device = "mtia"
        elif ProfilerActivity.PrivateUse1 in self.activities:
            self.use_device = _get_privateuse1_backend_name()

        # user-defined metadata to be amended to the trace
        self.preset_metadata: Dict[str, str] = {}

    def start(self):
        """Prepare the trace and start it immediately."""
        self.prepare_trace()
        self.start_trace()

    def stop(self):
        """Stop the trace (see ``stop_trace``)."""
        self.stop_trace()

    def prepare_trace(self):
        """Create (or, with ``acc_events``, reuse) the autograd profiler and prepare it for tracing."""
        # With acc_events the same profiler object is reused across cycles so that
        # FunctionEvents accumulate; otherwise every cycle gets a fresh profiler.
        if (self.profiler is None) or (not self.acc_events):
            self.profiler = prof.profile(
                use_cpu=(ProfilerActivity.CPU in self.activities),
                use_device=self.use_device,
                record_shapes=self.record_shapes,
                with_flops=self.with_flops,
                profile_memory=self.profile_memory,
                with_stack=self.with_stack,
                with_modules=self.with_modules,
                use_kineto=True,
                experimental_config=self.experimental_config,
                acc_events=self.acc_events,
            )
        self.profiler._prepare_trace()

    def start_trace(self):
        """Start the execution trace observer (if any) and the profiler, then attach trace metadata."""
        if self.execution_trace_observer:
            self.execution_trace_observer.start()
        assert self.profiler is not None
        self.profiler._start_trace()

        # Record which optional collection features are enabled.
        if self.profile_memory:
            self.add_metadata_json("profile_memory", "1")
        if self.with_stack:
            self.add_metadata_json("with_stack", "1")
        if self.record_shapes:
            self.add_metadata_json("record_shapes", "1")
        if self.with_modules:
            self.add_metadata_json("with_modules", "1")
        if self.with_flops:
            self.add_metadata_json("with_flops", "1")

        if kineto_available():
            dist_info = self._get_distributed_info()
            if dist_info:
                self.add_metadata_json("distributedInfo", json.dumps(dist_info))

            if hasattr(torch, "_inductor"):
                import torch._inductor.config as inductor_config

                if inductor_config.triton.cudagraphs:
                    os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
                    self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
                    # FIXME: CUDA Graph does not work well with CUPTI teardown.
                    #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
                    #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
                    # Workaround: turn off CUPTI teardown when using CUDA Graphs.
                    os.environ["TEARDOWN_CUPTI"] = "0"

            # Insert the preset user metadata to the trace
            for k, v in self.preset_metadata.items():
                self.add_metadata_json(k, v)

    def stop_trace(self):
        """Stop the execution trace observer (if any), then exit the autograd profiler context."""
        if self.execution_trace_observer:
            self.execution_trace_observer.stop()
        assert self.profiler is not None
        self.profiler.__exit__(None, None, None)

    @staticmethod
    def _export_gzipped(path: str, export_fn: Callable[[str], Any]) -> Any:
        """Run ``export_fn`` on a temporary ``.json`` file and gzip that file into ``path``.

        The temporary file is always removed, even if exporting or compressing fails.
        The bytes are copied verbatim, so no locale-dependent text decoding is involved.

        Returns:
            Whatever ``export_fn`` returns.
        """
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
        fp.close()
        try:
            retvalue = export_fn(fp.name)
            with open(fp.name, "rb") as fin, gzip.open(path, "wb") as fout:
                shutil.copyfileobj(fin, fout)
            return retvalue
        finally:
            os.remove(fp.name)

    def export_chrome_trace(self, path: str):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        last cycle in schedule is exported. A ``path`` ending in ``.gz`` produces a
        gzip-compressed file.
        """
        assert self.profiler
        if path.endswith(".gz"):
            return self._export_gzipped(path, self.profiler.export_chrome_trace)
        return self.profiler.export_chrome_trace(path)

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Save stack traces to a file

        Args:
            path (str): save stacks file to this location;
            metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
        """
        assert self.profiler
        return self.profiler.export_stacks(path, metric)

    def toggle_collection_dynamic(
        self, enable: bool, activities: Iterable[ProfilerActivity]
    ):
        """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops
        (CPU) and CUDA activity supported in Kineto. Does nothing if the profiler has not been created yet.

        Args:
            enable (bool): ``True`` turns collection of ``activities`` on, ``False`` turns it off.
            activities (iterable): list of activity groups to use in profiling, supported values:
                ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``
        Examples:

        .. code-block:: python

            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]
            ) as p:
                code_to_profile_0()
                # turn off collection of all CUDA activity
                p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_1()
                # turn on collection of all CUDA activity
                p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_2()
            print(p.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
        """
        if not self.profiler:
            return
        self.profiler.toggle_collection_dynamic(enable, activities)

    def key_averages(
        self, group_by_input_shape: bool = False, group_by_stack_n: int = 0
    ):
        """Averages events, grouping them by operator name and (optionally) input shapes and
        stack.

        .. note::
            To use shape/stack functionality make sure to set record_shapes/with_stack
            when creating profiler context manager.
        """
        assert self.profiler
        return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)

    def events(self):
        """
        Returns the list of unaggregated profiler events,
        to be used in the trace callback or after the profiling is finished
        """
        assert self.profiler
        return self.profiler.function_events

    def add_metadata(self, key: str, value: str):
        """
        Adds a user defined metadata with a string key and a string value
        into the trace file
        """
        # json.dumps yields a correctly escaped JSON string literal. It escapes
        # backslashes and control characters such as newlines as well as double quotes,
        # so arbitrary user strings cannot corrupt the trace's JSON.
        torch.autograd._add_metadata_json(key, json.dumps(value, ensure_ascii=False))

    def add_metadata_json(self, key: str, value: str):
        """
        Adds a user defined metadata with a string key and a valid json value
        into the trace file
        """
        torch.autograd._add_metadata_json(key, value)

    def preset_metadata_json(self, key: str, value: str):
        """
        Preset a user defined metadata when the profiler is not started
        and added into the trace file later.
        Metadata is in the format of a string key and a valid json value
        """
        self.preset_metadata[key] = value

    def _get_distributed_info(self):
        """Describe the torch.distributed setup for the trace metadata.

        Returns:
            A dict with backend, rank, world size, process-group count/configs and, for the
            ``nccl`` backend, the NCCL version; or ``None`` if torch.distributed is not
            available or not initialized.
        """
        import torch.distributed as dist

        if not dist.is_available() or not dist.is_initialized():
            return None

        backend = dist.get_backend()
        dist_info = {
            "backend": backend,
            "rank": dist.get_rank(),
            "world_size": dist.get_world_size(),
            "pg_count": dist.get_pg_count(),
            "pg_config": dist.distributed_c10d._get_all_pg_configs(),
        }
        if backend == "nccl":
            nccl_version = torch.cuda.nccl.version()
            dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
        return dist_info

    def _memory_profile(self) -> MemoryProfile:
        """Build a ``MemoryProfile`` from the collected kineto results.

        Raises:
            ValueError: unless ``record_shapes``, ``profile_memory`` and ``with_stack`` are all set.
        """
        required = ("record_shapes", "profile_memory", "with_stack")
        missing = [f"{i}=True" for i in required if not getattr(self, i)]
        if missing:
            raise ValueError(f"{', '.join(missing)} required for memory profiling.")

        assert self.profiler is not None and self.profiler.kineto_results is not None
        return MemoryProfile(self.profiler.kineto_results)

    def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
        """Export memory event information from the profiler collected
        tree for a given device, and export a timeline plot. There are 3
        exportable files using ``export_memory_timeline``, each controlled by the
        ``path``'s suffix.

        - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline
          plot will be embedded as a PNG file in the HTML file.

        - For plot points consisting of ``[times, [sizes by category]]``, where
          ``times`` are timestamps and ``sizes`` are memory usage for each category.
          The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON
          (``.json.gz``) depending on the suffix.

        - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
          event will consist of ``(timestamp, action, numbytes, category)``, where
          ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
          and ``category`` is one of the enums from
          ``torch.profiler._memory_profiler.Category``.

        Output: Memory timeline written as gzipped JSON, JSON, or HTML.
        """
        # Default to device 0 of a non-CUDA profiled accelerator; otherwise cuda:0 when
        # CUDA is available, falling back to cpu.
        if device is None and self.use_device and self.use_device != "cuda":
            device = self.use_device + ":0"

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Construct the memory timeline plot data
        mem_tl = MemoryProfileTimeline(self._memory_profile())
        self.mem_tl = mem_tl

        # Depending on the file suffix, save the data as json.gz or json.
        # For html, we can embed the image into an HTML file.
        if path.endswith(".html"):
            mem_tl.export_memory_timeline_html(path, device)
        elif path.endswith(".gz"):
            export_fn = (
                mem_tl.export_memory_timeline_raw
                if path.endswith("raw.json.gz")
                else mem_tl.export_memory_timeline
            )
            self._export_gzipped(path, lambda tmp_path: export_fn(tmp_path, device))
        else:
            mem_tl.export_memory_timeline(path, device)
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker
class ProfilerAction(Enum):
    """
    Profiler actions that can be taken at the specified intervals.

    A schedule callable maps each step number to one of these values; the
    profiler reacts to each (previous action, current action) pair.
    """

    NONE = 0  # no profiling activity scheduled for this step
    WARMUP = 1  # the trace is prepared, but recording has not started yet
    RECORD = 2  # events are being recorded
    RECORD_AND_SAVE = 3  # last recording step; the trace is stopped and handed off after it
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker
def schedule(
    *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
) -> Callable:
    """
    Build a callable suitable for the profiler ``schedule`` argument.

    The returned function maps a step number to a ``ProfilerAction``:

    * the first ``skip_first`` steps map to ``NONE``;
    * the remaining steps are grouped into cycles of ``wait + warmup + active``
      steps: ``wait`` steps of ``NONE``, ``warmup`` steps of ``WARMUP``, then
      ``active`` recording steps, all ``RECORD`` except the last, which is
      ``RECORD_AND_SAVE``;
    * when ``repeat`` is positive, every step after ``repeat`` complete cycles
      maps to ``NONE``; ``repeat == 0`` lets cycles continue until profiling
      finishes.

    A warning is emitted when ``warmup`` is zero, because skipping warmup can
    skew profiler results.

    Raises:
        AssertionError: if ``active`` is not positive or any other argument is
            negative (checked when the schedule is built).
    """
    assert (
        wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
    ), "Invalid profiler schedule arguments"
    if warmup == 0:
        warn("Profiler won't be using warmup, this can skew profiler results")

    def schedule_fn(step: int) -> ProfilerAction:
        assert step >= 0
        if step < skip_first:
            return ProfilerAction.NONE
        offset = step - skip_first
        cycle_len = wait + warmup + active
        # Once the requested number of cycles is exhausted, stop profiling.
        if repeat > 0 and offset / cycle_len >= repeat:
            return ProfilerAction.NONE
        phase = offset % cycle_len
        if phase < wait:
            return ProfilerAction.NONE
        if phase < wait + warmup:
            return ProfilerAction.WARMUP
        if phase >= cycle_len - 1:
            # Final active step of the cycle.
            return ProfilerAction.RECORD_AND_SAVE
        return ProfilerAction.RECORD

    return schedule_fn
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker
def _default_schedule_fn(_: int) -> ProfilerAction:
    """
    Schedule used when the profiler is created without a ``schedule``.

    The step number is ignored: every step answers ``ProfilerAction.RECORD``,
    so events are recorded continuously from the very first step.
    """
    action = ProfilerAction.RECORD
    return action
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker
def tensorboard_trace_handler(
    dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False
):
    """
    Create an ``on_trace_ready`` callback that writes traces for TensorBoard.

    Each invocation of the returned callback calls ``export_chrome_trace`` on
    the profiler it receives, targeting a file inside ``dir_name``. The
    directory is created on demand, so it can be handed directly to TensorBoard
    as its logdir.

    Args:
        dir_name: directory that receives the trace files.
        worker_name: file-name prefix; should be unique for each worker in a
            distributed run. When empty or ``None`` it defaults to
            ``'[hostname]_[pid]'``, computed on the first callback invocation
            and reused on later ones.
        use_gzip: when true, the file name gets a trailing ``.gz`` (whether the
            output is actually compressed is presumably decided by
            ``export_chrome_trace`` from that suffix; not visible here).

    Returns:
        A callable taking the profiler object and returning ``None``.

    Raises:
        RuntimeError: from the callback, when ``dir_name`` cannot be created.
    """
    import os
    import socket
    import time

    suffix = ".pt.trace.json.gz" if use_gzip else ".pt.trace.json"

    def _ensure_dir() -> None:
        # Create the output directory if it is missing, wrapping any failure.
        if os.path.isdir(dir_name):
            return
        try:
            os.makedirs(dir_name, exist_ok=True)
        except Exception as e:
            raise RuntimeError("Can't create directory: " + dir_name) from e

    def handler_fn(prof) -> None:
        nonlocal worker_name
        _ensure_dir()
        if not worker_name:
            worker_name = f"{socket.gethostname()}_{os.getpid()}"
        # A nanosecond timestamp keeps consecutive exports from clashing.
        stamp = time.time_ns()
        target = os.path.join(dir_name, f"{worker_name}.{stamp}{suffix}")
        prof.export_chrome_trace(target)

    return handler_fn
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Workerclass profile(_KinetoProfile):
489*da0073e9SAndroid Build Coastguard Worker    """Profiler context manager.
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker    Args:
492*da0073e9SAndroid Build Coastguard Worker        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
493*da0073e9SAndroid Build Coastguard Worker            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
494*da0073e9SAndroid Build Coastguard Worker            ``torch.profiler.ProfilerActivity.XPU``.
495*da0073e9SAndroid Build Coastguard Worker            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
496*da0073e9SAndroid Build Coastguard Worker            or (when available) ProfilerActivity.XPU.
497*da0073e9SAndroid Build Coastguard Worker        schedule (Callable): callable that takes step (int) as a single parameter and returns
498*da0073e9SAndroid Build Coastguard Worker            ``ProfilerAction`` value that specifies the profiler action to perform at each step.
499*da0073e9SAndroid Build Coastguard Worker        on_trace_ready (Callable): callable that is called at each step when ``schedule``
500*da0073e9SAndroid Build Coastguard Worker            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
501*da0073e9SAndroid Build Coastguard Worker        record_shapes (bool): save information about operator's input shapes.
502*da0073e9SAndroid Build Coastguard Worker        profile_memory (bool): track tensor memory allocation/deallocation.
503*da0073e9SAndroid Build Coastguard Worker        with_stack (bool): record source information (file and line number) for the ops.
504*da0073e9SAndroid Build Coastguard Worker        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
505*da0073e9SAndroid Build Coastguard Worker            (matrix multiplication and 2D convolution).
506*da0073e9SAndroid Build Coastguard Worker        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
508*da0073e9SAndroid Build Coastguard Worker            module B's forward which contains an aten::add op,
509*da0073e9SAndroid Build Coastguard Worker            then aten::add's module hierarchy is A.B
            Note that this support exists, at the moment, only for TorchScript models
511*da0073e9SAndroid Build Coastguard Worker            and not eager mode models.
512*da0073e9SAndroid Build Coastguard Worker        experimental_config (_ExperimentalConfig) : A set of experimental options
513*da0073e9SAndroid Build Coastguard Worker            used for Kineto library features. Note, backward compatibility is not guaranteed.
514*da0073e9SAndroid Build Coastguard Worker        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
515*da0073e9SAndroid Build Coastguard Worker            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
516*da0073e9SAndroid Build Coastguard Worker            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
517*da0073e9SAndroid Build Coastguard Worker            When this argument is included the observer start() and stop() will be called for the
518*da0073e9SAndroid Build Coastguard Worker            same time window as PyTorch profiler. See the examples section below for a code sample.
519*da0073e9SAndroid Build Coastguard Worker        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
520*da0073e9SAndroid Build Coastguard Worker        use_cuda (bool):
521*da0073e9SAndroid Build Coastguard Worker            .. deprecated:: 1.8.1
522*da0073e9SAndroid Build Coastguard Worker                use ``activities`` instead.
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker    .. note::
525*da0073e9SAndroid Build Coastguard Worker        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
526*da0073e9SAndroid Build Coastguard Worker        Non-default schedules are useful when profiling long training jobs
527*da0073e9SAndroid Build Coastguard Worker        and allow the user to obtain multiple traces at the different iterations
528*da0073e9SAndroid Build Coastguard Worker        of the training process.
529*da0073e9SAndroid Build Coastguard Worker        The default schedule simply records all the events continuously for the
530*da0073e9SAndroid Build Coastguard Worker        duration of the context manager.
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker    .. note::
533*da0073e9SAndroid Build Coastguard Worker        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker        After profiling, result files can be found in the specified directory. Use the command:
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        ``tensorboard --logdir dir_name``
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker        to see the results in TensorBoard.
542*da0073e9SAndroid Build Coastguard Worker        For more information, see
543*da0073e9SAndroid Build Coastguard Worker        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
544*da0073e9SAndroid Build Coastguard Worker
545*da0073e9SAndroid Build Coastguard Worker    .. note::
546*da0073e9SAndroid Build Coastguard Worker        Enabling shape and stack tracing results in additional overhead.
547*da0073e9SAndroid Build Coastguard Worker        When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
548*da0073e9SAndroid Build Coastguard Worker        that may further prevent certain optimizations that depend on the reference count and introduce
549*da0073e9SAndroid Build Coastguard Worker        extra tensor copies.
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker    Examples:
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker    .. code-block:: python
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(
557*da0073e9SAndroid Build Coastguard Worker            activities=[
558*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CPU,
559*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CUDA,
560*da0073e9SAndroid Build Coastguard Worker            ]
561*da0073e9SAndroid Build Coastguard Worker        ) as p:
562*da0073e9SAndroid Build Coastguard Worker            code_to_profile()
563*da0073e9SAndroid Build Coastguard Worker        print(p.key_averages().table(
564*da0073e9SAndroid Build Coastguard Worker            sort_by="self_cuda_time_total", row_limit=-1))
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker    .. code-block:: python
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker        # Non-default profiler schedule allows user to turn profiler on and off
571*da0073e9SAndroid Build Coastguard Worker        # on different iterations of the training loop;
572*da0073e9SAndroid Build Coastguard Worker        # trace_handler is called every time a new trace becomes available
573*da0073e9SAndroid Build Coastguard Worker        def trace_handler(prof):
574*da0073e9SAndroid Build Coastguard Worker            print(prof.key_averages().table(
575*da0073e9SAndroid Build Coastguard Worker                sort_by="self_cuda_time_total", row_limit=-1))
576*da0073e9SAndroid Build Coastguard Worker            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(
579*da0073e9SAndroid Build Coastguard Worker            activities=[
580*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CPU,
581*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CUDA,
582*da0073e9SAndroid Build Coastguard Worker            ],
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker            # In this example with wait=1, warmup=1, active=2, repeat=1,
585*da0073e9SAndroid Build Coastguard Worker            # profiler will skip the first step/iteration,
586*da0073e9SAndroid Build Coastguard Worker            # start warming up on the second, record
            # the third and the fourth iterations,
588*da0073e9SAndroid Build Coastguard Worker            # after which the trace will become available
589*da0073e9SAndroid Build Coastguard Worker            # and on_trace_ready (when set) is called;
            # because repeat=1 the cycle runs only once; with repeat=0 it
            # would repeat, starting again with the next step
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(
593*da0073e9SAndroid Build Coastguard Worker                wait=1,
594*da0073e9SAndroid Build Coastguard Worker                warmup=1,
595*da0073e9SAndroid Build Coastguard Worker                active=2,
596*da0073e9SAndroid Build Coastguard Worker                repeat=1),
597*da0073e9SAndroid Build Coastguard Worker            on_trace_ready=trace_handler
598*da0073e9SAndroid Build Coastguard Worker            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
599*da0073e9SAndroid Build Coastguard Worker            # used when outputting for tensorboard
600*da0073e9SAndroid Build Coastguard Worker            ) as p:
601*da0073e9SAndroid Build Coastguard Worker                for iter in range(N):
602*da0073e9SAndroid Build Coastguard Worker                    code_iteration_to_profile(iter)
603*da0073e9SAndroid Build Coastguard Worker                    # send a signal to the profiler that the next iteration has started
604*da0073e9SAndroid Build Coastguard Worker                    p.step()
605*da0073e9SAndroid Build Coastguard Worker
    The following sample shows how to set up an Execution Trace Observer (`execution_trace_observer`)
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker    .. code-block:: python
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(
611*da0073e9SAndroid Build Coastguard Worker            ...
612*da0073e9SAndroid Build Coastguard Worker            execution_trace_observer=(
613*da0073e9SAndroid Build Coastguard Worker                ExecutionTraceObserver().register_callback("./execution_trace.json")
614*da0073e9SAndroid Build Coastguard Worker            ),
615*da0073e9SAndroid Build Coastguard Worker        ) as p:
616*da0073e9SAndroid Build Coastguard Worker            for iter in range(N):
617*da0073e9SAndroid Build Coastguard Worker                code_iteration_to_profile(iter)
618*da0073e9SAndroid Build Coastguard Worker                p.step()
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker    You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
621*da0073e9SAndroid Build Coastguard Worker    Note: One can also pass any object satisfying the _ITraceObserver interface.
622*da0073e9SAndroid Build Coastguard Worker    """
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker    def __init__(
625*da0073e9SAndroid Build Coastguard Worker        self,
626*da0073e9SAndroid Build Coastguard Worker        *,
627*da0073e9SAndroid Build Coastguard Worker        activities: Optional[Iterable[ProfilerActivity]] = None,
628*da0073e9SAndroid Build Coastguard Worker        schedule: Optional[Callable[[int], ProfilerAction]] = None,
629*da0073e9SAndroid Build Coastguard Worker        on_trace_ready: Optional[Callable[..., Any]] = None,
630*da0073e9SAndroid Build Coastguard Worker        record_shapes: bool = False,
631*da0073e9SAndroid Build Coastguard Worker        profile_memory: bool = False,
632*da0073e9SAndroid Build Coastguard Worker        with_stack: bool = False,
633*da0073e9SAndroid Build Coastguard Worker        with_flops: bool = False,
634*da0073e9SAndroid Build Coastguard Worker        with_modules: bool = False,
635*da0073e9SAndroid Build Coastguard Worker        experimental_config: Optional[_ExperimentalConfig] = None,
636*da0073e9SAndroid Build Coastguard Worker        execution_trace_observer: Optional[_ITraceObserver] = None,
637*da0073e9SAndroid Build Coastguard Worker        acc_events: bool = False,
638*da0073e9SAndroid Build Coastguard Worker        # deprecated:
639*da0073e9SAndroid Build Coastguard Worker        use_cuda: Optional[bool] = None,
640*da0073e9SAndroid Build Coastguard Worker    ):
641*da0073e9SAndroid Build Coastguard Worker        activities_set = set(activities) if activities else supported_activities()
642*da0073e9SAndroid Build Coastguard Worker        if use_cuda is not None:
643*da0073e9SAndroid Build Coastguard Worker            warn(
644*da0073e9SAndroid Build Coastguard Worker                "`use_cuda` is deprecated, use `activities` argument instead",
645*da0073e9SAndroid Build Coastguard Worker                FutureWarning,
646*da0073e9SAndroid Build Coastguard Worker                stacklevel=2,
647*da0073e9SAndroid Build Coastguard Worker            )
648*da0073e9SAndroid Build Coastguard Worker            if use_cuda:
649*da0073e9SAndroid Build Coastguard Worker                activities_set.add(ProfilerActivity.CUDA)
650*da0073e9SAndroid Build Coastguard Worker            elif ProfilerActivity.CUDA in activities_set:
651*da0073e9SAndroid Build Coastguard Worker                activities_set.remove(ProfilerActivity.CUDA)
652*da0073e9SAndroid Build Coastguard Worker        assert len(activities_set) > 0, "No valid profiler activities found"
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker        super().__init__(
655*da0073e9SAndroid Build Coastguard Worker            activities=activities,
656*da0073e9SAndroid Build Coastguard Worker            record_shapes=record_shapes,
657*da0073e9SAndroid Build Coastguard Worker            profile_memory=profile_memory,
658*da0073e9SAndroid Build Coastguard Worker            with_stack=with_stack,
659*da0073e9SAndroid Build Coastguard Worker            with_flops=with_flops,
660*da0073e9SAndroid Build Coastguard Worker            with_modules=with_modules,
661*da0073e9SAndroid Build Coastguard Worker            experimental_config=experimental_config,
662*da0073e9SAndroid Build Coastguard Worker            execution_trace_observer=execution_trace_observer,
663*da0073e9SAndroid Build Coastguard Worker            acc_events=acc_events,
664*da0073e9SAndroid Build Coastguard Worker        )
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker        if schedule:
667*da0073e9SAndroid Build Coastguard Worker            self.schedule = schedule
668*da0073e9SAndroid Build Coastguard Worker            # add step markers into the trace and table view
669*da0073e9SAndroid Build Coastguard Worker            self.record_steps = True
670*da0073e9SAndroid Build Coastguard Worker        else:
671*da0073e9SAndroid Build Coastguard Worker            self.schedule = _default_schedule_fn
672*da0073e9SAndroid Build Coastguard Worker            self.record_steps = False
673*da0073e9SAndroid Build Coastguard Worker        self.on_trace_ready = on_trace_ready
674*da0073e9SAndroid Build Coastguard Worker        self.step_num = 0
675*da0073e9SAndroid Build Coastguard Worker        self.current_action = self.schedule(self.step_num)
676*da0073e9SAndroid Build Coastguard Worker        self.step_rec_fn: Optional[prof.record_function] = None
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        self.action_map: Dict[
679*da0073e9SAndroid Build Coastguard Worker            Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
680*da0073e9SAndroid Build Coastguard Worker        ] = {
681*da0073e9SAndroid Build Coastguard Worker            # key is (prev_action, current_action), value is action list corresponding to the state pair.
682*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.NONE, ProfilerAction.NONE): [],
683*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
684*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.NONE, ProfilerAction.RECORD): [
685*da0073e9SAndroid Build Coastguard Worker                self.prepare_trace,
686*da0073e9SAndroid Build Coastguard Worker                self.start_trace,
687*da0073e9SAndroid Build Coastguard Worker            ],
688*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
689*da0073e9SAndroid Build Coastguard Worker                self.prepare_trace,
690*da0073e9SAndroid Build Coastguard Worker                self.start_trace,
691*da0073e9SAndroid Build Coastguard Worker            ],
692*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.WARMUP, ProfilerAction.NONE): [
693*da0073e9SAndroid Build Coastguard Worker                partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
694*da0073e9SAndroid Build Coastguard Worker                self.start_trace,
695*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
696*da0073e9SAndroid Build Coastguard Worker            ],
697*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
698*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
699*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
700*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD, ProfilerAction.NONE): [
701*da0073e9SAndroid Build Coastguard Worker                partial(warn, "Incorrect schedule: RECORD followed by NONE"),
702*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
703*da0073e9SAndroid Build Coastguard Worker            ],
704*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
705*da0073e9SAndroid Build Coastguard Worker                partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
706*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
707*da0073e9SAndroid Build Coastguard Worker            ],
708*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
709*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
710*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
711*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
712*da0073e9SAndroid Build Coastguard Worker                self._trace_ready,
713*da0073e9SAndroid Build Coastguard Worker            ],
714*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
715*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
716*da0073e9SAndroid Build Coastguard Worker                self._trace_ready,
717*da0073e9SAndroid Build Coastguard Worker                self.prepare_trace,
718*da0073e9SAndroid Build Coastguard Worker            ],
719*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
720*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
721*da0073e9SAndroid Build Coastguard Worker                self._trace_ready,
722*da0073e9SAndroid Build Coastguard Worker                self.prepare_trace,
723*da0073e9SAndroid Build Coastguard Worker                self.start_trace,
724*da0073e9SAndroid Build Coastguard Worker            ],
725*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
726*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
727*da0073e9SAndroid Build Coastguard Worker                self._trace_ready,
728*da0073e9SAndroid Build Coastguard Worker                self.prepare_trace,
729*da0073e9SAndroid Build Coastguard Worker                self.start_trace,
730*da0073e9SAndroid Build Coastguard Worker            ],
731*da0073e9SAndroid Build Coastguard Worker            # used for exit action
732*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
733*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
734*da0073e9SAndroid Build Coastguard Worker            (ProfilerAction.RECORD_AND_SAVE, None): [
735*da0073e9SAndroid Build Coastguard Worker                self.stop_trace,
736*da0073e9SAndroid Build Coastguard Worker                self._trace_ready,
737*da0073e9SAndroid Build Coastguard Worker            ],
738*da0073e9SAndroid Build Coastguard Worker        }
739*da0073e9SAndroid Build Coastguard Worker        # Start tracking increments to profiler step, this will be used
740*da0073e9SAndroid Build Coastguard Worker        # by Kineto
741*da0073e9SAndroid Build Coastguard Worker        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker    def __enter__(self):
744*da0073e9SAndroid Build Coastguard Worker        self.start()
745*da0073e9SAndroid Build Coastguard Worker        return self
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker    def __exit__(self, exc_type, exc_val, exc_tb):
748*da0073e9SAndroid Build Coastguard Worker        self.stop()
749*da0073e9SAndroid Build Coastguard Worker        prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
750*da0073e9SAndroid Build Coastguard Worker        if self.execution_trace_observer:
751*da0073e9SAndroid Build Coastguard Worker            self.execution_trace_observer.cleanup()
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker    def start(self):
754*da0073e9SAndroid Build Coastguard Worker        self._transit_action(ProfilerAction.NONE, self.current_action)
755*da0073e9SAndroid Build Coastguard Worker        if self.record_steps:
756*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn = prof.record_function(
757*da0073e9SAndroid Build Coastguard Worker                "ProfilerStep#" + str(self.step_num)
758*da0073e9SAndroid Build Coastguard Worker            )
759*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn.__enter__()
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker    def stop(self):
762*da0073e9SAndroid Build Coastguard Worker        if self.record_steps and self.step_rec_fn:
763*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn.__exit__(None, None, None)
764*da0073e9SAndroid Build Coastguard Worker        self._transit_action(self.current_action, None)
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker    def step(self):
767*da0073e9SAndroid Build Coastguard Worker        """
768*da0073e9SAndroid Build Coastguard Worker        Signals the profiler that the next profiling step has started.
769*da0073e9SAndroid Build Coastguard Worker        """
770*da0073e9SAndroid Build Coastguard Worker        if self.record_steps and self.step_rec_fn:
771*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn.__exit__(None, None, None)
772*da0073e9SAndroid Build Coastguard Worker        prev_action = self.current_action
773*da0073e9SAndroid Build Coastguard Worker        self.step_num += 1
774*da0073e9SAndroid Build Coastguard Worker        self.current_action = self.schedule(self.step_num)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        self._transit_action(prev_action, self.current_action)
777*da0073e9SAndroid Build Coastguard Worker        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)
778*da0073e9SAndroid Build Coastguard Worker
779*da0073e9SAndroid Build Coastguard Worker        if self.record_steps:
780*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn = prof.record_function(
781*da0073e9SAndroid Build Coastguard Worker                "ProfilerStep#" + str(self.step_num)
782*da0073e9SAndroid Build Coastguard Worker            )
783*da0073e9SAndroid Build Coastguard Worker            self.step_rec_fn.__enter__()
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker    def _trace_ready(self):
786*da0073e9SAndroid Build Coastguard Worker        if self.on_trace_ready:
787*da0073e9SAndroid Build Coastguard Worker            self.on_trace_ready(self)
788*da0073e9SAndroid Build Coastguard Worker
789*da0073e9SAndroid Build Coastguard Worker    def _transit_action(self, prev_action, current_action):
790*da0073e9SAndroid Build Coastguard Worker        action_list = self.action_map.get((prev_action, current_action))
791*da0073e9SAndroid Build Coastguard Worker        if action_list:
792*da0073e9SAndroid Build Coastguard Worker            for action in action_list:
793*da0073e9SAndroid Build Coastguard Worker                action()
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker    def _stats(self) -> Optional[prof._ProfilerStats]:
796*da0073e9SAndroid Build Coastguard Worker        if self.profiler is None:
797*da0073e9SAndroid Build Coastguard Worker            return None
798*da0073e9SAndroid Build Coastguard Worker        return self.profiler._stats
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Workerclass ExecutionTraceObserver(_ITraceObserver):
802*da0073e9SAndroid Build Coastguard Worker    """Execution Trace Observer
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker    Each process can have a single ExecutionTraceObserver instance. The observer
805*da0073e9SAndroid Build Coastguard Worker    can be added to record function callbacks via calling register_callback()
806*da0073e9SAndroid Build Coastguard Worker    explicitly. Without calling unregister_callback(), repeated calls to
807*da0073e9SAndroid Build Coastguard Worker    register_callback() will not add additional observers to record function
808*da0073e9SAndroid Build Coastguard Worker    callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
809*da0073e9SAndroid Build Coastguard Worker    methods control when the event data is recorded.
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker    Deleting or calling unregister_callback() will remove the observer from the
812*da0073e9SAndroid Build Coastguard Worker    record function callbacks, finalize the output file, and will stop
813*da0073e9SAndroid Build Coastguard Worker    incurring any overheads.
814*da0073e9SAndroid Build Coastguard Worker    """
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
817*da0073e9SAndroid Build Coastguard Worker        """
818*da0073e9SAndroid Build Coastguard Worker        Initializes the default states.
819*da0073e9SAndroid Build Coastguard Worker        """
820*da0073e9SAndroid Build Coastguard Worker        self._registered = False
821*da0073e9SAndroid Build Coastguard Worker        self._execution_trace_running = False
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker    def __del__(self):
824*da0073e9SAndroid Build Coastguard Worker        """
825*da0073e9SAndroid Build Coastguard Worker        Calls unregister_callback() to make sure to finalize outputs.
826*da0073e9SAndroid Build Coastguard Worker        """
827*da0073e9SAndroid Build Coastguard Worker        self.unregister_callback()
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker    def register_callback(self, output_file_path: str) -> Self:
830*da0073e9SAndroid Build Coastguard Worker        """
831*da0073e9SAndroid Build Coastguard Worker        Adds ET observer to record function callbacks. The data will be
832*da0073e9SAndroid Build Coastguard Worker        written to output_file_path.
833*da0073e9SAndroid Build Coastguard Worker        """
834*da0073e9SAndroid Build Coastguard Worker        if not self._registered:
835*da0073e9SAndroid Build Coastguard Worker            self._output_file_path = output_file_path
836*da0073e9SAndroid Build Coastguard Worker            self._registered = _add_execution_trace_observer(output_file_path)
837*da0073e9SAndroid Build Coastguard Worker        return self
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker    def unregister_callback(self):
840*da0073e9SAndroid Build Coastguard Worker        """
841*da0073e9SAndroid Build Coastguard Worker        Removes ET observer from record function callbacks.
842*da0073e9SAndroid Build Coastguard Worker        """
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker        def _save_triton_kernels():
845*da0073e9SAndroid Build Coastguard Worker            # Save the kernel paths for the generated kernels
846*da0073e9SAndroid Build Coastguard Worker            from torch._inductor.codecache import PyCodeCache as PyCodeCache
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker            kernel_files = [
849*da0073e9SAndroid Build Coastguard Worker                v.__file__
850*da0073e9SAndroid Build Coastguard Worker                for v in PyCodeCache.cache.values()
851*da0073e9SAndroid Build Coastguard Worker                if getattr(v, "__file__", None) is not None
852*da0073e9SAndroid Build Coastguard Worker            ]
853*da0073e9SAndroid Build Coastguard Worker            work_dir, file_name = os.path.split(self._output_file_path)
854*da0073e9SAndroid Build Coastguard Worker            resource_dir = os.path.join(
855*da0073e9SAndroid Build Coastguard Worker                work_dir, os.path.splitext(file_name)[0] + "_resources"
856*da0073e9SAndroid Build Coastguard Worker            )
857*da0073e9SAndroid Build Coastguard Worker            if not os.path.exists(resource_dir):
858*da0073e9SAndroid Build Coastguard Worker                os.mkdir(resource_dir)
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker            for kernel_file in kernel_files:
861*da0073e9SAndroid Build Coastguard Worker                if kernel_file is None:
862*da0073e9SAndroid Build Coastguard Worker                    continue
863*da0073e9SAndroid Build Coastguard Worker                path, name = os.path.split(kernel_file)
864*da0073e9SAndroid Build Coastguard Worker                dst = os.path.join(resource_dir, name)
865*da0073e9SAndroid Build Coastguard Worker                shutil.copyfile(kernel_file, dst)
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker        if self._registered:
868*da0073e9SAndroid Build Coastguard Worker            self.stop()
869*da0073e9SAndroid Build Coastguard Worker            try:
870*da0073e9SAndroid Build Coastguard Worker                _save_triton_kernels()
871*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
872*da0073e9SAndroid Build Coastguard Worker                warn(f"Execution trace failed to save kernels: {e}")
873*da0073e9SAndroid Build Coastguard Worker            _remove_execution_trace_observer()
874*da0073e9SAndroid Build Coastguard Worker            self._registered = False
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker    @property
877*da0073e9SAndroid Build Coastguard Worker    def is_registered(self):
878*da0073e9SAndroid Build Coastguard Worker        """
879*da0073e9SAndroid Build Coastguard Worker        Returns True if the execution trace observer is registered, otherwise False.
880*da0073e9SAndroid Build Coastguard Worker        """
881*da0073e9SAndroid Build Coastguard Worker        return self._registered
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker    def is_running(self):
884*da0073e9SAndroid Build Coastguard Worker        """
885*da0073e9SAndroid Build Coastguard Worker        Returns True if the observer is running, otherwise False.
886*da0073e9SAndroid Build Coastguard Worker        """
887*da0073e9SAndroid Build Coastguard Worker        return self._execution_trace_running
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker    def start(self):
890*da0073e9SAndroid Build Coastguard Worker        """
891*da0073e9SAndroid Build Coastguard Worker        Starts to capture.
892*da0073e9SAndroid Build Coastguard Worker        """
893*da0073e9SAndroid Build Coastguard Worker        if self._registered and not self._execution_trace_running:
894*da0073e9SAndroid Build Coastguard Worker            _enable_execution_trace_observer()
895*da0073e9SAndroid Build Coastguard Worker            self._execution_trace_running = True
896*da0073e9SAndroid Build Coastguard Worker            self._record_pg_config()
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker    def stop(self):
899*da0073e9SAndroid Build Coastguard Worker        """
900*da0073e9SAndroid Build Coastguard Worker        Stops to capture.
901*da0073e9SAndroid Build Coastguard Worker        """
902*da0073e9SAndroid Build Coastguard Worker        if self._execution_trace_running:
903*da0073e9SAndroid Build Coastguard Worker            _disable_execution_trace_observer()
904*da0073e9SAndroid Build Coastguard Worker            self._execution_trace_running = False
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker    def cleanup(self):
907*da0073e9SAndroid Build Coastguard Worker        """
908*da0073e9SAndroid Build Coastguard Worker        Calls unregister_callback() to make sure to finalize outputs.
909*da0073e9SAndroid Build Coastguard Worker        """
910*da0073e9SAndroid Build Coastguard Worker        self.unregister_callback()
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker    def get_output_file_path(self) -> str:
913*da0073e9SAndroid Build Coastguard Worker        """
914*da0073e9SAndroid Build Coastguard Worker        Returns the output file name.
915*da0073e9SAndroid Build Coastguard Worker        """
916*da0073e9SAndroid Build Coastguard Worker        if self.is_registered:
917*da0073e9SAndroid Build Coastguard Worker            return self._output_file_path
918*da0073e9SAndroid Build Coastguard Worker        else:
919*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
920*da0073e9SAndroid Build Coastguard Worker                "A callback to the ET profiler needs to be registered "
921*da0073e9SAndroid Build Coastguard Worker                "first before getting the output file path"
922*da0073e9SAndroid Build Coastguard Worker            )
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker    def _record_pg_config(self) -> None:
925*da0073e9SAndroid Build Coastguard Worker        # Records the PG config info to the trace as node:
926*da0073e9SAndroid Build Coastguard Worker        #  ## process_group:init ##
927*da0073e9SAndroid Build Coastguard Worker        if (
928*da0073e9SAndroid Build Coastguard Worker            self.is_registered
929*da0073e9SAndroid Build Coastguard Worker            and torch.distributed.is_available()
930*da0073e9SAndroid Build Coastguard Worker            and torch.distributed.is_initialized()
931*da0073e9SAndroid Build Coastguard Worker        ):
932*da0073e9SAndroid Build Coastguard Worker            pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
933*da0073e9SAndroid Build Coastguard Worker            torch.autograd._record_function_with_args_enter(
934*da0073e9SAndroid Build Coastguard Worker                "## process_group:init ##", json.dumps(pg_config_info)
935*da0073e9SAndroid Build Coastguard Worker            )
936