1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport gzip 3*da0073e9SAndroid Build Coastguard Workerimport json 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport shutil 6*da0073e9SAndroid Build Coastguard Workerimport tempfile 7*da0073e9SAndroid Build Coastguard Workerfrom abc import ABC, abstractmethod 8*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum 9*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 10*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple 11*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self 12*da0073e9SAndroid Build Coastguard Workerfrom warnings import warn 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerimport torch 15*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.profiler as prof 16*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _get_privateuse1_backend_name 17*da0073e9SAndroid Build Coastguard Workerfrom torch._C._profiler import ( 18*da0073e9SAndroid Build Coastguard Worker _add_execution_trace_observer, 19*da0073e9SAndroid Build Coastguard Worker _disable_execution_trace_observer, 20*da0073e9SAndroid Build Coastguard Worker _enable_execution_trace_observer, 21*da0073e9SAndroid Build Coastguard Worker _ExperimentalConfig, 22*da0073e9SAndroid Build Coastguard Worker _remove_execution_trace_observer, 23*da0073e9SAndroid Build Coastguard Worker) 24*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import kineto_available, ProfilerActivity 25*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker__all__ = [ 29*da0073e9SAndroid Build Coastguard 
Worker "supported_activities", 30*da0073e9SAndroid Build Coastguard Worker "ProfilerAction", 31*da0073e9SAndroid Build Coastguard Worker "schedule", 32*da0073e9SAndroid Build Coastguard Worker "tensorboard_trace_handler", 33*da0073e9SAndroid Build Coastguard Worker "profile", 34*da0073e9SAndroid Build Coastguard Worker "ExecutionTraceObserver", 35*da0073e9SAndroid Build Coastguard Worker] 36*da0073e9SAndroid Build Coastguard WorkerPROFILER_STEP_NAME = "ProfilerStep" 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Workerdef supported_activities(): 40*da0073e9SAndroid Build Coastguard Worker """ 41*da0073e9SAndroid Build Coastguard Worker Returns a set of supported profiler tracing activities. 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker Note: profiler uses CUPTI library to trace on-device CUDA kernels. 44*da0073e9SAndroid Build Coastguard Worker In case when CUDA is enabled but CUPTI is not available, passing 45*da0073e9SAndroid Build Coastguard Worker ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA 46*da0073e9SAndroid Build Coastguard Worker profiling code (same as in the legacy ``torch.autograd.profiler``). 47*da0073e9SAndroid Build Coastguard Worker This, in turn, results in including CUDA time in the profiler table output, 48*da0073e9SAndroid Build Coastguard Worker but not in the JSON trace. 49*da0073e9SAndroid Build Coastguard Worker """ 50*da0073e9SAndroid Build Coastguard Worker return torch.autograd._supported_activities() 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Workerclass _ITraceObserver(ABC): 54*da0073e9SAndroid Build Coastguard Worker """Abstract interface for a Trace observer. 
55*da0073e9SAndroid Build Coastguard Worker This satisfies 3 methods: start, stop and cleanup""" 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker @abstractmethod 58*da0073e9SAndroid Build Coastguard Worker def start(self): 59*da0073e9SAndroid Build Coastguard Worker pass 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker @abstractmethod 62*da0073e9SAndroid Build Coastguard Worker def stop(self): 63*da0073e9SAndroid Build Coastguard Worker pass 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker @abstractmethod 66*da0073e9SAndroid Build Coastguard Worker def cleanup(self): 67*da0073e9SAndroid Build Coastguard Worker pass 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Workerclass _KinetoProfile: 71*da0073e9SAndroid Build Coastguard Worker """Low-level profiler wrap the autograd profile 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker Args: 74*da0073e9SAndroid Build Coastguard Worker activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: 75*da0073e9SAndroid Build Coastguard Worker ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, 76*da0073e9SAndroid Build Coastguard Worker ``torch.profiler.ProfilerActivity.XPU``. 77*da0073e9SAndroid Build Coastguard Worker Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA 78*da0073e9SAndroid Build Coastguard Worker or (when available) ProfilerActivity.XPU. 79*da0073e9SAndroid Build Coastguard Worker record_shapes (bool): save information about operator's input shapes. 80*da0073e9SAndroid Build Coastguard Worker profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline`` 81*da0073e9SAndroid Build Coastguard Worker for more details). 
82*da0073e9SAndroid Build Coastguard Worker with_stack (bool): record source information (file and line number) for the ops. 83*da0073e9SAndroid Build Coastguard Worker with_flops (bool): use formula to estimate the FLOPS of specific operators 84*da0073e9SAndroid Build Coastguard Worker (matrix multiplication and 2D convolution). 85*da0073e9SAndroid Build Coastguard Worker with_modules (bool): record module hierarchy (including function names) 86*da0073e9SAndroid Build Coastguard Worker corresponding to the callstack of the op. e.g. If module A's forward call's 87*da0073e9SAndroid Build Coastguard Worker module B's forward which contains an aten::add op, 88*da0073e9SAndroid Build Coastguard Worker then aten::add's module hierarchy is A.B 89*da0073e9SAndroid Build Coastguard Worker Note that this support exist, at the moment, only for TorchScript models 90*da0073e9SAndroid Build Coastguard Worker and not eager mode models. 91*da0073e9SAndroid Build Coastguard Worker experimental_config (_ExperimentalConfig) : A set of experimental options 92*da0073e9SAndroid Build Coastguard Worker used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed. 93*da0073e9SAndroid Build Coastguard Worker execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object. 94*da0073e9SAndroid Build Coastguard Worker `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based 95*da0073e9SAndroid Build Coastguard Worker representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators. 96*da0073e9SAndroid Build Coastguard Worker When this argument is included the observer start() and stop() will be called for the 97*da0073e9SAndroid Build Coastguard Worker same time window as PyTorch profiler. 
98*da0073e9SAndroid Build Coastguard Worker acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker .. note:: 102*da0073e9SAndroid Build Coastguard Worker This API is experimental and subject to change in the future. 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker Enabling shape and stack tracing results in additional overhead. 105*da0073e9SAndroid Build Coastguard Worker When record_shapes=True is specified, profiler will temporarily hold references to the tensors; 106*da0073e9SAndroid Build Coastguard Worker that may further prevent certain optimizations that depend on the reference count and introduce 107*da0073e9SAndroid Build Coastguard Worker extra tensor copies. 108*da0073e9SAndroid Build Coastguard Worker """ 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker def __init__( 111*da0073e9SAndroid Build Coastguard Worker self, 112*da0073e9SAndroid Build Coastguard Worker *, 113*da0073e9SAndroid Build Coastguard Worker activities: Optional[Iterable[ProfilerActivity]] = None, 114*da0073e9SAndroid Build Coastguard Worker record_shapes: bool = False, 115*da0073e9SAndroid Build Coastguard Worker profile_memory: bool = False, 116*da0073e9SAndroid Build Coastguard Worker with_stack: bool = False, 117*da0073e9SAndroid Build Coastguard Worker with_flops: bool = False, 118*da0073e9SAndroid Build Coastguard Worker with_modules: bool = False, 119*da0073e9SAndroid Build Coastguard Worker experimental_config: Optional[_ExperimentalConfig] = None, 120*da0073e9SAndroid Build Coastguard Worker execution_trace_observer: Optional[_ITraceObserver] = None, 121*da0073e9SAndroid Build Coastguard Worker acc_events: bool = False, 122*da0073e9SAndroid Build Coastguard Worker ): 123*da0073e9SAndroid Build Coastguard Worker self.activities 
= set(activities) if activities else supported_activities() 124*da0073e9SAndroid Build Coastguard Worker self.record_shapes = record_shapes 125*da0073e9SAndroid Build Coastguard Worker self.with_flops = with_flops 126*da0073e9SAndroid Build Coastguard Worker self.profile_memory = profile_memory 127*da0073e9SAndroid Build Coastguard Worker self.with_stack = with_stack 128*da0073e9SAndroid Build Coastguard Worker self.with_modules = with_modules 129*da0073e9SAndroid Build Coastguard Worker self.experimental_config = experimental_config 130*da0073e9SAndroid Build Coastguard Worker self.execution_trace_observer = execution_trace_observer 131*da0073e9SAndroid Build Coastguard Worker self.acc_events = acc_events 132*da0073e9SAndroid Build Coastguard Worker self.profiler: Optional[prof.profile] = None 133*da0073e9SAndroid Build Coastguard Worker self.mem_tl: Optional[MemoryProfileTimeline] = None 134*da0073e9SAndroid Build Coastguard Worker self.use_device = None 135*da0073e9SAndroid Build Coastguard Worker if ProfilerActivity.CUDA in self.activities: 136*da0073e9SAndroid Build Coastguard Worker self.use_device = "cuda" 137*da0073e9SAndroid Build Coastguard Worker elif ProfilerActivity.XPU in self.activities: 138*da0073e9SAndroid Build Coastguard Worker self.use_device = "xpu" 139*da0073e9SAndroid Build Coastguard Worker elif ProfilerActivity.MTIA in self.activities: 140*da0073e9SAndroid Build Coastguard Worker self.use_device = "mtia" 141*da0073e9SAndroid Build Coastguard Worker elif ProfilerActivity.PrivateUse1 in self.activities: 142*da0073e9SAndroid Build Coastguard Worker self.use_device = _get_privateuse1_backend_name() 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker # user-defined metadata to be amended to the trace 145*da0073e9SAndroid Build Coastguard Worker self.preset_metadata: Dict[str, str] = {} 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker def start(self): 
148*da0073e9SAndroid Build Coastguard Worker self.prepare_trace() 149*da0073e9SAndroid Build Coastguard Worker self.start_trace() 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def stop(self): 152*da0073e9SAndroid Build Coastguard Worker self.stop_trace() 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker def prepare_trace(self): 155*da0073e9SAndroid Build Coastguard Worker if (self.profiler is None) or (not self.acc_events): 156*da0073e9SAndroid Build Coastguard Worker self.profiler = prof.profile( 157*da0073e9SAndroid Build Coastguard Worker use_cpu=(ProfilerActivity.CPU in self.activities), 158*da0073e9SAndroid Build Coastguard Worker use_device=self.use_device, 159*da0073e9SAndroid Build Coastguard Worker record_shapes=self.record_shapes, 160*da0073e9SAndroid Build Coastguard Worker with_flops=self.with_flops, 161*da0073e9SAndroid Build Coastguard Worker profile_memory=self.profile_memory, 162*da0073e9SAndroid Build Coastguard Worker with_stack=self.with_stack, 163*da0073e9SAndroid Build Coastguard Worker with_modules=self.with_modules, 164*da0073e9SAndroid Build Coastguard Worker use_kineto=True, 165*da0073e9SAndroid Build Coastguard Worker experimental_config=self.experimental_config, 166*da0073e9SAndroid Build Coastguard Worker acc_events=self.acc_events, 167*da0073e9SAndroid Build Coastguard Worker ) 168*da0073e9SAndroid Build Coastguard Worker self.profiler._prepare_trace() 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def start_trace(self): 171*da0073e9SAndroid Build Coastguard Worker if self.execution_trace_observer: 172*da0073e9SAndroid Build Coastguard Worker self.execution_trace_observer.start() 173*da0073e9SAndroid Build Coastguard Worker assert self.profiler is not None 174*da0073e9SAndroid Build Coastguard Worker self.profiler._start_trace() 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard 
Worker if self.profile_memory: 177*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("profile_memory", "1") 178*da0073e9SAndroid Build Coastguard Worker if self.with_stack: 179*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("with_stack", "1") 180*da0073e9SAndroid Build Coastguard Worker if self.record_shapes: 181*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("record_shapes", "1") 182*da0073e9SAndroid Build Coastguard Worker if self.with_modules: 183*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("with_modules", "1") 184*da0073e9SAndroid Build Coastguard Worker if self.with_flops: 185*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("with_flops", "1") 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker if kineto_available(): 188*da0073e9SAndroid Build Coastguard Worker dist_info = self._get_distributed_info() 189*da0073e9SAndroid Build Coastguard Worker if dist_info: 190*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("distributedInfo", json.dumps(dist_info)) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker if hasattr(torch, "_inductor"): 193*da0073e9SAndroid Build Coastguard Worker import torch._inductor.config as inductor_config 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker if inductor_config.triton.cudagraphs: 196*da0073e9SAndroid Build Coastguard Worker os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" 197*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") 198*da0073e9SAndroid Build Coastguard Worker # FIXME: CUDA Graph does not work well with CUPTI teardown. 
199*da0073e9SAndroid Build Coastguard Worker # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) 200*da0073e9SAndroid Build Coastguard Worker # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) 201*da0073e9SAndroid Build Coastguard Worker # Workaround: turn off CUPTI teardown when using CUDA Graphs. 202*da0073e9SAndroid Build Coastguard Worker os.environ["TEARDOWN_CUPTI"] = "0" 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker # Insert the preset user metadata to the trace 205*da0073e9SAndroid Build Coastguard Worker for k, v in self.preset_metadata.items(): 206*da0073e9SAndroid Build Coastguard Worker self.add_metadata_json(k, v) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker def stop_trace(self): 209*da0073e9SAndroid Build Coastguard Worker if self.execution_trace_observer: 210*da0073e9SAndroid Build Coastguard Worker self.execution_trace_observer.stop() 211*da0073e9SAndroid Build Coastguard Worker assert self.profiler is not None 212*da0073e9SAndroid Build Coastguard Worker self.profiler.__exit__(None, None, None) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def export_chrome_trace(self, path: str): 215*da0073e9SAndroid Build Coastguard Worker """ 216*da0073e9SAndroid Build Coastguard Worker Exports the collected trace in Chrome JSON format. If kineto is enabled, only 217*da0073e9SAndroid Build Coastguard Worker last cycle in schedule is exported. 
218*da0073e9SAndroid Build Coastguard Worker """ 219*da0073e9SAndroid Build Coastguard Worker assert self.profiler 220*da0073e9SAndroid Build Coastguard Worker if path.endswith(".gz"): 221*da0073e9SAndroid Build Coastguard Worker fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) 222*da0073e9SAndroid Build Coastguard Worker fp.close() 223*da0073e9SAndroid Build Coastguard Worker retvalue = self.profiler.export_chrome_trace(fp.name) 224*da0073e9SAndroid Build Coastguard Worker with open(fp.name) as fin: 225*da0073e9SAndroid Build Coastguard Worker with gzip.open(path, "wt") as fout: 226*da0073e9SAndroid Build Coastguard Worker fout.writelines(fin) 227*da0073e9SAndroid Build Coastguard Worker os.remove(fp.name) 228*da0073e9SAndroid Build Coastguard Worker return retvalue 229*da0073e9SAndroid Build Coastguard Worker else: 230*da0073e9SAndroid Build Coastguard Worker return self.profiler.export_chrome_trace(path) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): 233*da0073e9SAndroid Build Coastguard Worker """Save stack traces to a file 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker Args: 236*da0073e9SAndroid Build Coastguard Worker path (str): save stacks file to this location; 237*da0073e9SAndroid Build Coastguard Worker metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" 238*da0073e9SAndroid Build Coastguard Worker """ 239*da0073e9SAndroid Build Coastguard Worker assert self.profiler 240*da0073e9SAndroid Build Coastguard Worker return self.profiler.export_stacks(path, metric) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def toggle_collection_dynamic( 243*da0073e9SAndroid Build Coastguard Worker self, enable: bool, activities: Iterable[ProfilerActivity] 244*da0073e9SAndroid Build Coastguard Worker ): 245*da0073e9SAndroid Build 
Coastguard Worker """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops 246*da0073e9SAndroid Build Coastguard Worker (CPU) and CUDA activity supported in Kineto 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker Args: 249*da0073e9SAndroid Build Coastguard Worker activities (iterable): list of activity groups to use in profiling, supported values: 250*da0073e9SAndroid Build Coastguard Worker ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA`` 251*da0073e9SAndroid Build Coastguard Worker Examples: 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker .. code-block:: python 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile( 256*da0073e9SAndroid Build Coastguard Worker activities=[ 257*da0073e9SAndroid Build Coastguard Worker torch.profiler.ProfilerActivity.CPU, 258*da0073e9SAndroid Build Coastguard Worker torch.profiler.ProfilerActivity.CUDA, 259*da0073e9SAndroid Build Coastguard Worker ] 260*da0073e9SAndroid Build Coastguard Worker ) as p: 261*da0073e9SAndroid Build Coastguard Worker code_to_profile_0() 262*da0073e9SAndroid Build Coastguard Worker // turn off collection of all CUDA activity 263*da0073e9SAndroid Build Coastguard Worker p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA]) 264*da0073e9SAndroid Build Coastguard Worker code_to_profile_1() 265*da0073e9SAndroid Build Coastguard Worker // turn on collection of all CUDA activity 266*da0073e9SAndroid Build Coastguard Worker p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA]) 267*da0073e9SAndroid Build Coastguard Worker code_to_profile_2() 268*da0073e9SAndroid Build Coastguard Worker print(p.key_averages().table( 269*da0073e9SAndroid Build Coastguard Worker sort_by="self_cuda_time_total", row_limit=-1)) 270*da0073e9SAndroid Build Coastguard 
Worker """ 271*da0073e9SAndroid Build Coastguard Worker if not self.profiler: 272*da0073e9SAndroid Build Coastguard Worker return 273*da0073e9SAndroid Build Coastguard Worker self.profiler.toggle_collection_dynamic(enable, activities) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker def key_averages( 276*da0073e9SAndroid Build Coastguard Worker self, group_by_input_shape: bool = False, group_by_stack_n: int = 0 277*da0073e9SAndroid Build Coastguard Worker ): 278*da0073e9SAndroid Build Coastguard Worker """Averages events, grouping them by operator name and (optionally) input shapes and 279*da0073e9SAndroid Build Coastguard Worker stack. 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker .. note:: 282*da0073e9SAndroid Build Coastguard Worker To use shape/stack functionality make sure to set record_shapes/with_stack 283*da0073e9SAndroid Build Coastguard Worker when creating profiler context manager. 284*da0073e9SAndroid Build Coastguard Worker """ 285*da0073e9SAndroid Build Coastguard Worker assert self.profiler 286*da0073e9SAndroid Build Coastguard Worker return self.profiler.key_averages(group_by_input_shape, group_by_stack_n) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker def events(self): 289*da0073e9SAndroid Build Coastguard Worker """ 290*da0073e9SAndroid Build Coastguard Worker Returns the list of unaggregated profiler events, 291*da0073e9SAndroid Build Coastguard Worker to be used in the trace callback or after the profiling is finished 292*da0073e9SAndroid Build Coastguard Worker """ 293*da0073e9SAndroid Build Coastguard Worker assert self.profiler 294*da0073e9SAndroid Build Coastguard Worker return self.profiler.function_events 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker def add_metadata(self, key: str, value: str): 297*da0073e9SAndroid Build Coastguard Worker """ 298*da0073e9SAndroid Build 
Coastguard Worker Adds a user defined metadata with a string key and a string value 299*da0073e9SAndroid Build Coastguard Worker into the trace file 300*da0073e9SAndroid Build Coastguard Worker """ 301*da0073e9SAndroid Build Coastguard Worker wrapped_value = '"' + value.replace('"', '\\"') + '"' 302*da0073e9SAndroid Build Coastguard Worker torch.autograd._add_metadata_json(key, wrapped_value) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def add_metadata_json(self, key: str, value: str): 305*da0073e9SAndroid Build Coastguard Worker """ 306*da0073e9SAndroid Build Coastguard Worker Adds a user defined metadata with a string key and a valid json value 307*da0073e9SAndroid Build Coastguard Worker into the trace file 308*da0073e9SAndroid Build Coastguard Worker """ 309*da0073e9SAndroid Build Coastguard Worker torch.autograd._add_metadata_json(key, value) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker def preset_metadata_json(self, key: str, value: str): 312*da0073e9SAndroid Build Coastguard Worker """ 313*da0073e9SAndroid Build Coastguard Worker Preset a user defined metadata when the profiler is not started 314*da0073e9SAndroid Build Coastguard Worker and added into the trace file later. 
315*da0073e9SAndroid Build Coastguard Worker Metadata is in the format of a string key and a valid json value 316*da0073e9SAndroid Build Coastguard Worker """ 317*da0073e9SAndroid Build Coastguard Worker self.preset_metadata[key] = value 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker def _get_distributed_info(self): 320*da0073e9SAndroid Build Coastguard Worker import torch.distributed as dist 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker if not dist.is_available() or not dist.is_initialized(): 323*da0073e9SAndroid Build Coastguard Worker return None 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker backend = dist.get_backend() 326*da0073e9SAndroid Build Coastguard Worker dist_info = { 327*da0073e9SAndroid Build Coastguard Worker "backend": backend, 328*da0073e9SAndroid Build Coastguard Worker "rank": dist.get_rank(), 329*da0073e9SAndroid Build Coastguard Worker "world_size": dist.get_world_size(), 330*da0073e9SAndroid Build Coastguard Worker "pg_count": dist.get_pg_count(), 331*da0073e9SAndroid Build Coastguard Worker "pg_config": dist.distributed_c10d._get_all_pg_configs(), 332*da0073e9SAndroid Build Coastguard Worker } 333*da0073e9SAndroid Build Coastguard Worker if backend == "nccl": 334*da0073e9SAndroid Build Coastguard Worker nccl_version = torch.cuda.nccl.version() 335*da0073e9SAndroid Build Coastguard Worker dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version) 336*da0073e9SAndroid Build Coastguard Worker return dist_info 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker def _memory_profile(self) -> MemoryProfile: 339*da0073e9SAndroid Build Coastguard Worker required = ("record_shapes", "profile_memory", "with_stack") 340*da0073e9SAndroid Build Coastguard Worker missing = [f"{i}=True" for i in required if not getattr(self, i)] 341*da0073e9SAndroid Build Coastguard Worker if 
missing: 342*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"{', '.join(missing)} required for memory profiling.") 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker assert self.profiler is not None and self.profiler.kineto_results is not None 345*da0073e9SAndroid Build Coastguard Worker return MemoryProfile(self.profiler.kineto_results) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: 348*da0073e9SAndroid Build Coastguard Worker """Export memory event information from the profiler collected 349*da0073e9SAndroid Build Coastguard Worker tree for a given device, and export a timeline plot. There are 3 350*da0073e9SAndroid Build Coastguard Worker exportable files using ``export_memory_timeline``, each controlled by the 351*da0073e9SAndroid Build Coastguard Worker ``path``'s suffix. 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline 354*da0073e9SAndroid Build Coastguard Worker plot will be embedded as a PNG file in the HTML file. 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker - For plot points consisting of ``[times, [sizes by category]]``, where 357*da0073e9SAndroid Build Coastguard Worker ``times`` are timestamps and ``sizes`` are memory usage for each category. 358*da0073e9SAndroid Build Coastguard Worker The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON 359*da0073e9SAndroid Build Coastguard Worker (``.json.gz``) depending on the suffix. 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker - For raw memory points, use the suffix ``.raw.json.gz``. 
Each raw memory 362*da0073e9SAndroid Build Coastguard Worker event will consist of ``(timestamp, action, numbytes, category)``, where 363*da0073e9SAndroid Build Coastguard Worker ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``, 364*da0073e9SAndroid Build Coastguard Worker and ``category`` is one of the enums from 365*da0073e9SAndroid Build Coastguard Worker ``torch.profiler._memory_profiler.Category``. 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker Output: Memory timeline written as gzipped JSON, JSON, or HTML. 368*da0073e9SAndroid Build Coastguard Worker """ 369*da0073e9SAndroid Build Coastguard Worker # Default to device 0, if unset. Fallback on cpu. 370*da0073e9SAndroid Build Coastguard Worker if device is None and self.use_device and self.use_device != "cuda": 371*da0073e9SAndroid Build Coastguard Worker device = self.use_device + ":0" 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker if device is None: 374*da0073e9SAndroid Build Coastguard Worker device = "cuda:0" if torch.cuda.is_available() else "cpu" 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker # Construct the memory timeline plot data 377*da0073e9SAndroid Build Coastguard Worker self.mem_tl = MemoryProfileTimeline(self._memory_profile()) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker # Depending on the file suffix, save the data as json.gz or json. 380*da0073e9SAndroid Build Coastguard Worker # For html, we can embed the image into an HTML file. 
381*da0073e9SAndroid Build Coastguard Worker if path.endswith(".html"): 382*da0073e9SAndroid Build Coastguard Worker self.mem_tl.export_memory_timeline_html(path, device) 383*da0073e9SAndroid Build Coastguard Worker elif path.endswith(".gz"): 384*da0073e9SAndroid Build Coastguard Worker fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) 385*da0073e9SAndroid Build Coastguard Worker fp.close() 386*da0073e9SAndroid Build Coastguard Worker if path.endswith("raw.json.gz"): 387*da0073e9SAndroid Build Coastguard Worker self.mem_tl.export_memory_timeline_raw(fp.name, device) 388*da0073e9SAndroid Build Coastguard Worker else: 389*da0073e9SAndroid Build Coastguard Worker self.mem_tl.export_memory_timeline(fp.name, device) 390*da0073e9SAndroid Build Coastguard Worker with open(fp.name) as fin: 391*da0073e9SAndroid Build Coastguard Worker with gzip.open(path, "wt") as fout: 392*da0073e9SAndroid Build Coastguard Worker fout.writelines(fin) 393*da0073e9SAndroid Build Coastguard Worker os.remove(fp.name) 394*da0073e9SAndroid Build Coastguard Worker else: 395*da0073e9SAndroid Build Coastguard Worker self.mem_tl.export_memory_timeline(path, device) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Workerclass ProfilerAction(Enum): 399*da0073e9SAndroid Build Coastguard Worker """ 400*da0073e9SAndroid Build Coastguard Worker Profiler actions that can be taken at the specified intervals 401*da0073e9SAndroid Build Coastguard Worker """ 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker NONE = 0 404*da0073e9SAndroid Build Coastguard Worker WARMUP = 1 405*da0073e9SAndroid Build Coastguard Worker RECORD = 2 406*da0073e9SAndroid Build Coastguard Worker RECORD_AND_SAVE = 3 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Workerdef schedule( 410*da0073e9SAndroid Build Coastguard 
Worker *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0 411*da0073e9SAndroid Build Coastguard Worker) -> Callable: 412*da0073e9SAndroid Build Coastguard Worker """ 413*da0073e9SAndroid Build Coastguard Worker Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip 414*da0073e9SAndroid Build Coastguard Worker the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps, 415*da0073e9SAndroid Build Coastguard Worker then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps. 416*da0073e9SAndroid Build Coastguard Worker The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that 417*da0073e9SAndroid Build Coastguard Worker the cycles will continue until the profiling is finished. 418*da0073e9SAndroid Build Coastguard Worker """ 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker def schedule_fn(step: int) -> ProfilerAction: 421*da0073e9SAndroid Build Coastguard Worker assert step >= 0 422*da0073e9SAndroid Build Coastguard Worker if step < skip_first: 423*da0073e9SAndroid Build Coastguard Worker return ProfilerAction.NONE 424*da0073e9SAndroid Build Coastguard Worker else: 425*da0073e9SAndroid Build Coastguard Worker step -= skip_first 426*da0073e9SAndroid Build Coastguard Worker num_steps = wait + warmup + active 427*da0073e9SAndroid Build Coastguard Worker if repeat > 0 and step / num_steps >= repeat: 428*da0073e9SAndroid Build Coastguard Worker return ProfilerAction.NONE 429*da0073e9SAndroid Build Coastguard Worker mod_step = step % num_steps 430*da0073e9SAndroid Build Coastguard Worker if mod_step < wait: 431*da0073e9SAndroid Build Coastguard Worker return ProfilerAction.NONE 432*da0073e9SAndroid Build Coastguard Worker elif mod_step < wait + warmup: 433*da0073e9SAndroid Build Coastguard Worker return ProfilerAction.WARMUP 
def _default_schedule_fn(_: int) -> ProfilerAction:
    """
    Schedule used when the caller does not provide one.

    The step index is ignored: ``ProfilerAction.RECORD`` is returned for
    every profiler step, so recording starts immediately and never pauses.
    """
    always_record = ProfilerAction.RECORD
    return always_record
class profile(_KinetoProfile):
    """Profiler context manager.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        schedule (Callable): callable that takes step (int) as a single parameter and returns
            ``ProfilerAction`` value that specifies the profiler action to perform at each step.
        on_trace_ready (Callable): callable that is called at each step when ``schedule``
            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation.
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used for Kineto library features. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included the observer start() and stop() will be called for the
            same time window as PyTorch profiler. See the examples section below for a code sample.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.
        use_cuda (bool):
            .. deprecated:: 1.8.1
                use ``activities`` instead. When given, ``ProfilerActivity.CUDA`` is added to
                (``True``) or removed from (``False``) the resolved set of activities.

    .. note::
        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
        Non-default schedules are useful when profiling long training jobs
        and allow the user to obtain multiple traces at the different iterations
        of the training process.
        The default schedule simply records all the events continuously for the
        duration of the context manager.

    .. note::
        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:

        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``

        After profiling, result files can be found in the specified directory. Use the command:

        ``tensorboard --logdir dir_name``

        to see the results in TensorBoard.
        For more information, see
        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__

    .. note::
        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.


    Examples:

    .. code-block:: python

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            code_to_profile()
        print(p.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))

    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

    .. code-block:: python

        # Non-default profiler schedule allows user to turn profiler on and off
        # on different iterations of the training loop;
        # trace_handler is called every time a new trace becomes available
        def trace_handler(prof):
            print(prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],

            # In this example with wait=1, warmup=1, active=2, repeat=1,
            # profiler will skip the first step/iteration,
            # start warming up on the second, record
            # the third and the fourth iterations,
            # after which the trace will become available
            # and on_trace_ready (when set) is called;
            # the cycle repeats starting with the next step

            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2,
                repeat=1),
            on_trace_ready=trace_handler
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            # used when outputting for tensorboard
            ) as p:
                for iter in range(N):
                    code_iteration_to_profile(iter)
                    # send a signal to the profiler that the next iteration has started
                    p.step()

    The following sample shows how to set up an Execution Trace Observer (`execution_trace_observer`)

    .. code-block:: python

        with torch.profiler.profile(
            ...
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback("./execution_trace.json")
            ),
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                p.step()

    You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
    Note: One can also pass any object satisfying the _ITraceObserver interface.
    """

    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        schedule: Optional[Callable[[int], ProfilerAction]] = None,
        on_trace_ready: Optional[Callable[..., Any]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        acc_events: bool = False,
        # deprecated:
        use_cuda: Optional[bool] = None,
    ):
        # Resolve the requested activities once; an empty/None argument falls
        # back to everything the build supports.
        activities_set = set(activities) if activities else supported_activities()
        if use_cuda is not None:
            warn(
                "`use_cuda` is deprecated, use `activities` argument instead",
                FutureWarning,
                stacklevel=2,
            )
            if use_cuda:
                activities_set.add(ProfilerActivity.CUDA)
            elif ProfilerActivity.CUDA in activities_set:
                activities_set.remove(ProfilerActivity.CUDA)
        assert len(activities_set) > 0, "No valid profiler activities found"

        super().__init__(
            # Forward the resolved set rather than the raw argument: otherwise
            # the ``use_cuda`` adjustment above is discarded, and a one-shot
            # iterable would already have been consumed by ``set(activities)``.
            activities=activities_set,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
            experimental_config=experimental_config,
            execution_trace_observer=execution_trace_observer,
            acc_events=acc_events,
        )

        if schedule:
            self.schedule = schedule
            # add step markers into the trace and table view
            self.record_steps = True
        else:
            self.schedule = _default_schedule_fn
            self.record_steps = False
        self.on_trace_ready = on_trace_ready
        self.step_num = 0
        self.current_action = self.schedule(self.step_num)
        self.step_rec_fn: Optional[prof.record_function] = None

        self.action_map: Dict[
            Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
        ] = {
            # key is (prev_action, current_action), value is action list corresponding to the state pair.
            (ProfilerAction.NONE, ProfilerAction.NONE): [],
            (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
            (ProfilerAction.NONE, ProfilerAction.RECORD): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
                self.start_trace,
                self.stop_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
            (ProfilerAction.RECORD, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: RECORD followed by NONE"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
                partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
            (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
                self.stop_trace,
                self._trace_ready,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            # used for exit action
            (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
            (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
            (ProfilerAction.RECORD_AND_SAVE, None): [
                self.stop_trace,
                self._trace_ready,
            ],
        }
        # Start tracking increments to profiler step, this will be used
        # by Kineto
        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the step counter and finalize the observer, even if
        # stopping the trace (or the user's on_trace_ready callback) raises.
        try:
            self.stop()
        finally:
            prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
            if self.execution_trace_observer:
                self.execution_trace_observer.cleanup()

    def start(self):
        """
        Runs the transition from ``ProfilerAction.NONE`` to the action the
        schedule returned for step 0 and, when a custom schedule is in use,
        opens the ``ProfilerStep#<n>`` marker for the first step.
        """
        self._transit_action(ProfilerAction.NONE, self.current_action)
        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def stop(self):
        """
        Closes the current step marker (if any) and runs the exit transition
        for the current action.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        self._transit_action(self.current_action, None)

    def step(self):
        """
        Signals the profiler that the next profiling step has started.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        prev_action = self.current_action
        self.step_num += 1
        self.current_action = self.schedule(self.step_num)

        self._transit_action(prev_action, self.current_action)
        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def _trace_ready(self):
        # Hand the profiler itself to the user callback, when one was given.
        if self.on_trace_ready:
            self.on_trace_ready(self)

    def _transit_action(self, prev_action, current_action):
        # Pairs missing from ``action_map`` (e.g. exiting from NONE) and pairs
        # mapped to an empty list both mean "nothing to do".
        action_list = self.action_map.get((prev_action, current_action))
        if action_list:
            for action in action_list:
                action()

    def _stats(self) -> Optional[prof._ProfilerStats]:
        # No underlying profiler object yet means there are no stats to report.
        if self.profiler is None:
            return None
        return self.profiler._stats
def unregister_callback(self):
    """
    Removes ET observer from record function callbacks.

    Does nothing unless the observer is currently registered. Otherwise the
    capture is stopped, the source files of the cached generated kernels are
    copied (best effort, failures only produce a warning) into a
    ``<output file stem>_resources`` directory next to the output file, and
    the observer is removed.
    """

    def _save_triton_kernels():
        # Save the kernel paths for the generated kernels
        from torch._inductor.codecache import PyCodeCache as PyCodeCache

        kernel_files = [
            v.__file__
            for v in PyCodeCache.cache.values()
            if getattr(v, "__file__", None) is not None
        ]
        work_dir, file_name = os.path.split(self._output_file_path)
        resource_dir = os.path.join(
            work_dir, os.path.splitext(file_name)[0] + "_resources"
        )
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test followed by os.mkdir().
        os.makedirs(resource_dir, exist_ok=True)

        # kernel_files was already filtered for None above.
        for kernel_file in kernel_files:
            dst = os.path.join(resource_dir, os.path.basename(kernel_file))
            shutil.copyfile(kernel_file, dst)

    if self._registered:
        self.stop()
        try:
            _save_triton_kernels()
        except Exception as e:
            warn(f"Execution trace failed to save kernels: {e}")
        finally:
            # Runs even if the warning above is escalated to an exception, so
            # the observer never stays registered after an unregister attempt.
            _remove_execution_trace_observer()
            self._registered = False
892*da0073e9SAndroid Build Coastguard Worker """ 893*da0073e9SAndroid Build Coastguard Worker if self._registered and not self._execution_trace_running: 894*da0073e9SAndroid Build Coastguard Worker _enable_execution_trace_observer() 895*da0073e9SAndroid Build Coastguard Worker self._execution_trace_running = True 896*da0073e9SAndroid Build Coastguard Worker self._record_pg_config() 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker def stop(self): 899*da0073e9SAndroid Build Coastguard Worker """ 900*da0073e9SAndroid Build Coastguard Worker Stops to capture. 901*da0073e9SAndroid Build Coastguard Worker """ 902*da0073e9SAndroid Build Coastguard Worker if self._execution_trace_running: 903*da0073e9SAndroid Build Coastguard Worker _disable_execution_trace_observer() 904*da0073e9SAndroid Build Coastguard Worker self._execution_trace_running = False 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker def cleanup(self): 907*da0073e9SAndroid Build Coastguard Worker """ 908*da0073e9SAndroid Build Coastguard Worker Calls unregister_callback() to make sure to finalize outputs. 909*da0073e9SAndroid Build Coastguard Worker """ 910*da0073e9SAndroid Build Coastguard Worker self.unregister_callback() 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker def get_output_file_path(self) -> str: 913*da0073e9SAndroid Build Coastguard Worker """ 914*da0073e9SAndroid Build Coastguard Worker Returns the output file name. 
915*da0073e9SAndroid Build Coastguard Worker """ 916*da0073e9SAndroid Build Coastguard Worker if self.is_registered: 917*da0073e9SAndroid Build Coastguard Worker return self._output_file_path 918*da0073e9SAndroid Build Coastguard Worker else: 919*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 920*da0073e9SAndroid Build Coastguard Worker "A callback to the ET profiler needs to be registered " 921*da0073e9SAndroid Build Coastguard Worker "first before getting the output file path" 922*da0073e9SAndroid Build Coastguard Worker ) 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker def _record_pg_config(self) -> None: 925*da0073e9SAndroid Build Coastguard Worker # Records the PG config info to the trace as node: 926*da0073e9SAndroid Build Coastguard Worker # ## process_group:init ## 927*da0073e9SAndroid Build Coastguard Worker if ( 928*da0073e9SAndroid Build Coastguard Worker self.is_registered 929*da0073e9SAndroid Build Coastguard Worker and torch.distributed.is_available() 930*da0073e9SAndroid Build Coastguard Worker and torch.distributed.is_initialized() 931*da0073e9SAndroid Build Coastguard Worker ): 932*da0073e9SAndroid Build Coastguard Worker pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info 933*da0073e9SAndroid Build Coastguard Worker torch.autograd._record_function_with_args_enter( 934*da0073e9SAndroid Build Coastguard Worker "## process_group:init ##", json.dumps(pg_config_info) 935*da0073e9SAndroid Build Coastguard Worker ) 936