# mypy: allow-untyped-defs
from collections import defaultdict
from dataclasses import dataclass
from time import perf_counter_ns
from typing import Any, Dict, Iterable, List, Optional
from warnings import warn

import torch
import torch.cuda
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
    _disable_profiler,
    _enable_profiler,
    _kineto_step,
    _prepare_profiler,
    _ProfilerResult,
    _supported_activities,
    _toggle_collection_dynamic,
    DeviceType,
    kineto_available,
    ProfilerActivity,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
    MemRecordsAcc,
    OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future


__all__ = [
    "profile",
    "record_function",
    "emit_itt",
    "emit_nvtx",
    "load_nvprof",
    "EnforceUnique",
    "parse_nvprof_trace",
    "KinetoStepTracker",
    "EventList",
    "FunctionEvent",
    "MemRecordsAcc",
]

try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    # Minimal stand-in for contextlib.ContextDecorator: lets a context
    # manager double as a function decorator.  Subclasses must supply real
    # __enter__/__exit__ implementations.
    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            # Wrap ``func`` so every invocation runs inside ``with self:``.
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped


# global python state - whether profiler is currently enabled
# useful for fast python checks to reduce latency
_is_profiler_enabled: bool = False


def _set_is_profiler_enabled(enable: bool):
    """Set the module-level 'profiler is running' flag.

    Kept as a cheap Python-side boolean so hot paths can check profiler
    state without crossing into C++.
    """
    global _is_profiler_enabled
    _is_profiler_enabled = enable


def _run_on_profiler_start():
    # Hook invoked when profiling begins (see profile._start_trace below).
    _set_is_profiler_enabled(True)


def _run_on_profiler_stop():
    # Hook invoked when profiling ends (see profile.__exit__ below).
    _set_is_profiler_enabled(False)


@dataclass
class _ProfilerStats:
    "Profiler timing and stats used by developers to catch issues/regressions"
    # Wall-clock length of the profiling window (enable -> disable), seconds.
    profiling_window_duration_sec: float = 0
    # Number of FunctionEvents produced by the last processed run.
    number_of_events: int = 0
    # Durations (microseconds) of the individual profiler API calls below.
    profiler_prepare_call_duration_us: int = 0
    profiler_enable_call_duration_us: int = 0
    profiler_disable_call_duration_us: int = 0
    parse_kineto_call_duration_us: int = 0
    function_events_build_tree_call_duration_us: int = 0


class profile:
    """Context manager that manages autograd profiler state and holds a summary of results.

    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: profiler is thread local and is automatically propagated into the async tasks

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.

        use_cuda (bool, optional): Enables timing of CUDA events as well
            using the cudaEvent API. (will be deprecated)

        use_device (str, optional): Enables timing of device events.
            Adds approximately 4us of overhead to each tensor operation when use cuda.
            The valid device options are 'cuda', 'xpu', 'mtia' and 'privateuseone'.

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom most events (in a case
            of nested function calls). But for higher level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        with_flops (bool, optional): If with_flops is set, the profiler will estimate
            the FLOPs (floating point operations) value using the operator's input shape.
            This allows one to estimate the hardware performance. Currently,
            this option only works for the matrix multiplication and 2D convolution operators.

        profile_memory (bool, optional): track tensor memory allocation/deallocation.

        with_stack (bool, optional): record source information (file and line number) for the ops.

        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.

        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.

        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.

        experimental_config (_ExperimentalConfig) : A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.

        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles


    .. warning:
        Enabling memory profiling or source attribution incurs additional profiler
        overhead

    .. warning:
        This context manager should not be called recursively, i.e. no nested
        instances are allowed

    .. warning:
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_device = None`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        >>>     for _ in range(100):  # any normal python code, really!
        >>>         y = x ** 2
        >>>         y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total   CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        mul                                  32.048ms         32.048ms         200
        pow                                  27.041ms         27.041ms         200
        PowBackward0                         9.727ms          55.483ms         100
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
        torch::autograd::GraphRoot           691.816us        691.816us        100
        -----------------------------------  ---------------  ---------------  ---------------

    """

    def __init__(
        self,
        enabled=True,
        *,  # remaining arguments are keyword-only
        use_cuda=False,  # Deprecated
        use_device=None,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
        use_kineto=False,
        use_cpu=True,
        experimental_config=None,
        acc_events=False,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            # Disabled profiler: skip all setup; __enter__/__exit__ become no-ops.
            return
        self.use_cuda = use_cuda
        if self.use_cuda:
            warn(
                "The attribute `use_cuda` will be deprecated soon, "
                "please use ``use_device = 'cuda'`` instead.",
                FutureWarning,
                stacklevel=2,
            )
            # Deprecated flag maps onto the new device-selection mechanism.
            self.use_device: Optional[str] = "cuda"
        else:
            self.use_device = use_device
        # TODO Consider changing _function_events into data structure with size cap
        self._function_events: Optional[EventList] = None
        self._old_function_events: Optional[EventList] = None
        # Function event processing is done lazily
        self._needs_processing = False
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        # FLOPs estimation needs operator input shapes (see class docstring),
        # so with_flops implies record_shapes.
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.use_cpu = use_cpu
        self.acc_events = acc_events
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.kineto_results: Optional[_ProfilerResult] = None
        self.profiling_start_time_ns = 0
        self.profiling_end_time_ns = 0
        self._stats = _ProfilerStats()

        if not self.use_cpu:
            assert (
                use_kineto
            ), "Device-only events supported only with Kineto (use_kineto=True)"

        if self.use_device is not None:
            VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"]
            # A registered PrivateUse1 backend name is also accepted.
            if _get_privateuse1_backend_name() != "privateuseone":
                VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
            if self.use_device not in VALID_DEVICE_OPTIONS:
                warn(f"The {self.use_device} is not a valid device option.")
                self.use_device = None

        if self.use_device == "cuda" and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False
            self.use_device = None

        if self.use_device == "xpu" and not torch.xpu.is_available():
            warn("XPU is not available, disabling XPU profiling")
            self.use_device = None

        self.kineto_activities = set()
        if self.use_cpu:
            self.kineto_activities.add(ProfilerActivity.CPU)

        # Pick the profiler backend: pure Kineto where supported, otherwise a
        # per-device fallback mode that still requires CPU-side collection.
        self.profiler_kind = ProfilerState.KINETO
        if self.use_device == "cuda":
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.CUDA)
        elif self.use_device == "xpu":
            assert (
                use_kineto and ProfilerActivity.XPU in _supported_activities()
            ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
            self.kineto_activities.add(ProfilerActivity.XPU)
        elif self.use_device == "mtia":
            assert (
                use_kineto and ProfilerActivity.MTIA in _supported_activities()
            ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
            self.kineto_activities.add(ProfilerActivity.MTIA)
        elif self.use_device is not None and self.use_device != "privateuseone":
            if (
                not use_kineto
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
            ):
                assert (
                    self.use_cpu
                ), "Legacy custombackend profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)

        assert (
            len(self.kineto_activities) > 0
        ), "No activities specified for the profiler"

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config,
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self._prepare_trace()
        self._start_trace()
        return self

    def _prepare_trace(self):
        """Warm up the profiler backend before collection starts; records its cost in _stats."""
        self.entered = True
        t0 = perf_counter_ns()
        _prepare_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000)

    def _start_trace(self):
        """Enable event collection and mark the start of the profiling window."""
        self.entered = True
        _run_on_profiler_start()
        t0 = perf_counter_ns()
        _enable_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000)
        self.profiling_start_time_ns = t1

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        # Drain outstanding device work so device-side events are captured
        # before the profiler is disabled.
        if self.use_device and hasattr(torch, self.use_device):
            device_module = getattr(torch, self.use_device)
            if hasattr(device_module, "synchronize"):
                device_module.synchronize()

        # When accumulating across cycles, stash the previous cycle's events;
        # they are merged back in _ensure_function_events.
        if self._function_events and self.acc_events:
            self._old_function_events = self._function_events
        self._function_events = None
        self._needs_processing = True

        t0 = perf_counter_ns()

        self.kineto_results = _disable_profiler()
        t1 = perf_counter_ns()
        self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000)
        # Window ends when disable begins (t0), excluding teardown time itself.
        self.profiling_end_time_ns = t0

        _run_on_profiler_stop()
        self._stats.profiling_window_duration_sec = (
            (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9
        )

        # If we plan to accumulate events we should post process the function events
        # right away to retain the state across multiple start/stop calls
        if self.acc_events:
            self._ensure_function_events()
        return False

    def __repr__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return repr(self._function_events)

    def __str__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return str(self._function_events)

    def _ensure_function_events(self):
        """Process function events lazily if required.

        Parses the raw Kineto results into an EventList, builds the event
        tree, records timing stats, and merges in events stashed from prior
        cycles when acc_events is set.  No-op if events are already built.
        """
        if self._function_events is not None:
            return
        self._needs_processing = False

        t0 = perf_counter_ns()
        parsed_results = []
        if self.kineto_results:
            parsed_results = self._parse_kineto_results(self.kineto_results)
        t1 = perf_counter_ns()
        self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000)

        self._function_events = EventList(
            parsed_results,
            use_device=self.use_device,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        t0 = perf_counter_ns()
        self._function_events._build_tree()
        t1 = perf_counter_ns()
        self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000)
        self._stats.number_of_events = len(self._function_events)

        # Fold in events carried over from earlier profiling cycles.
        if self._old_function_events and self.acc_events:
            for evt in self._old_function_events:
                self._function_events.append(evt)
            self._old_function_events = None
        if self._function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    @property
    def function_events(self):
        # Lazily build events on first access (or after a new cycle).
        if self._function_events is None or self._needs_processing:
            self._ensure_function_events()
        return self._function_events

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._ensure_function_events()
        assert self._function_events is not None
        return self._function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    # Delegate the docstring to the underlying EventList method.
    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        last cycle in schedule is exported.
        """
        if kineto_available():
            # Kineto saves its own trace directly; no event post-processing needed.
            self.kineto_results.save(path)  # type: ignore[union-attr]
        else:
            self._ensure_function_events()
            return self._function_events.export_chrome_trace(path)  # type: ignore[union-attr]

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Export collected stacks (requires with_stack=True) aggregated by ``metric`` to ``path``."""
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self._function_events.export_stacks(path, metric)

    def toggle_collection_dynamic(
enabled: bool, activities: Iterable[ProfilerActivity] 474*da0073e9SAndroid Build Coastguard Worker ): 475*da0073e9SAndroid Build Coastguard Worker """ 476*da0073e9SAndroid Build Coastguard Worker Toggles the collection of activities for the current profiler instance. 477*da0073e9SAndroid Build Coastguard Worker """ 478*da0073e9SAndroid Build Coastguard Worker return _toggle_collection_dynamic(enabled, set(activities)) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): 481*da0073e9SAndroid Build Coastguard Worker self._ensure_function_events() 482*da0073e9SAndroid Build Coastguard Worker assert self._function_events is not None, "Expected profiling results" 483*da0073e9SAndroid Build Coastguard Worker return self._function_events.key_averages( 484*da0073e9SAndroid Build Coastguard Worker group_by_input_shape, group_by_stack_n 485*da0073e9SAndroid Build Coastguard Worker ) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker key_averages.__doc__ = EventList.key_averages.__doc__ 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker def total_average(self): 490*da0073e9SAndroid Build Coastguard Worker self._ensure_function_events() 491*da0073e9SAndroid Build Coastguard Worker assert self._function_events is not None, "Expected profiling results" 492*da0073e9SAndroid Build Coastguard Worker return self._function_events.total_average() 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker total_average.__doc__ = EventList.total_average.__doc__ 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @property 497*da0073e9SAndroid Build Coastguard Worker def self_cpu_time_total(self): 498*da0073e9SAndroid Build Coastguard Worker """Returns total time spent on CPU. 
499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker The total time is a sum of all self times across all the events. 501*da0073e9SAndroid Build Coastguard Worker """ 502*da0073e9SAndroid Build Coastguard Worker self._ensure_function_events() 503*da0073e9SAndroid Build Coastguard Worker assert self._function_events is not None 504*da0073e9SAndroid Build Coastguard Worker return self._function_events.self_cpu_time_total 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker def _parse_kineto_results(self, result: _ProfilerResult): 507*da0073e9SAndroid Build Coastguard Worker # result.events() has most of the events - PyTorch op-level and device-level events 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker trace_start_ns = result.trace_start_ns() 510*da0073e9SAndroid Build Coastguard Worker mem_records = [ 511*da0073e9SAndroid Build Coastguard Worker [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME 512*da0073e9SAndroid Build Coastguard Worker ] 513*da0073e9SAndroid Build Coastguard Worker oom_records = [ 514*da0073e9SAndroid Build Coastguard Worker evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME 515*da0073e9SAndroid Build Coastguard Worker ] 516*da0073e9SAndroid Build Coastguard Worker mem_records_acc = MemRecordsAcc(mem_records) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker def _cpu_memory_usage(mem_record): 519*da0073e9SAndroid Build Coastguard Worker return ( 520*da0073e9SAndroid Build Coastguard Worker mem_record.nbytes() 521*da0073e9SAndroid Build Coastguard Worker if mem_record.device_type() 522*da0073e9SAndroid Build Coastguard Worker in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] 523*da0073e9SAndroid Build Coastguard Worker else 0 524*da0073e9SAndroid Build Coastguard Worker ) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid 
Build Coastguard Worker def _device_memory_usage(mem_record): 527*da0073e9SAndroid Build Coastguard Worker return ( 528*da0073e9SAndroid Build Coastguard Worker mem_record.nbytes() 529*da0073e9SAndroid Build Coastguard Worker if mem_record.device_type() 530*da0073e9SAndroid Build Coastguard Worker in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP] 531*da0073e9SAndroid Build Coastguard Worker else 0 532*da0073e9SAndroid Build Coastguard Worker ) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker # Create and return FunctionEvent list, which contains all function events 535*da0073e9SAndroid Build Coastguard Worker # Here 2 function events are created: 536*da0073e9SAndroid Build Coastguard Worker # all_function_events contains all events associated with each kineto event from result 537*da0073e9SAndroid Build Coastguard Worker all_function_events = [] 538*da0073e9SAndroid Build Coastguard Worker # frontend_function_events contains the events in aten or torch frontend level, 539*da0073e9SAndroid Build Coastguard Worker # whose correlation id is 0 540*da0073e9SAndroid Build Coastguard Worker frontend_function_events = [] 541*da0073e9SAndroid Build Coastguard Worker device_corr_map: Dict[int, List[FunctionEvent]] = {} 542*da0073e9SAndroid Build Coastguard Worker max_evt_id = 0 543*da0073e9SAndroid Build Coastguard Worker for kineto_event in result.events(): 544*da0073e9SAndroid Build Coastguard Worker if _filter_name(kineto_event.name()): 545*da0073e9SAndroid Build Coastguard Worker continue 546*da0073e9SAndroid Build Coastguard Worker rel_start_ns = kineto_event.start_ns() - trace_start_ns 547*da0073e9SAndroid Build Coastguard Worker rel_end_ns = kineto_event.end_ns() - trace_start_ns 548*da0073e9SAndroid Build Coastguard Worker abs_end_ns = kineto_event.end_ns() 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker cpu_memory_usage = 0 551*da0073e9SAndroid Build Coastguard Worker 
device_memory_usage = 0 552*da0073e9SAndroid Build Coastguard Worker if kineto_event.device_type() == DeviceType.CPU: 553*da0073e9SAndroid Build Coastguard Worker # find the corresponding memory allocation events 554*da0073e9SAndroid Build Coastguard Worker for mem_record in mem_records_acc.in_interval( 555*da0073e9SAndroid Build Coastguard Worker kineto_event.start_ns() / 1000, abs_end_ns / 1000 556*da0073e9SAndroid Build Coastguard Worker ): 557*da0073e9SAndroid Build Coastguard Worker cpu_memory_usage += _cpu_memory_usage(mem_record[0]) 558*da0073e9SAndroid Build Coastguard Worker device_memory_usage += _device_memory_usage(mem_record[0]) 559*da0073e9SAndroid Build Coastguard Worker mem_record[1] = True 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker is_async = kineto_event.is_async() or ( 562*da0073e9SAndroid Build Coastguard Worker kineto_event.start_thread_id() != kineto_event.end_thread_id() 563*da0073e9SAndroid Build Coastguard Worker ) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker fe = FunctionEvent( 566*da0073e9SAndroid Build Coastguard Worker id=kineto_event.correlation_id(), 567*da0073e9SAndroid Build Coastguard Worker name=_rewrite_name(name=kineto_event.name(), with_wildcard=True), 568*da0073e9SAndroid Build Coastguard Worker trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False), 569*da0073e9SAndroid Build Coastguard Worker thread=kineto_event.start_thread_id(), 570*da0073e9SAndroid Build Coastguard Worker start_us=rel_start_ns / 1000, 571*da0073e9SAndroid Build Coastguard Worker end_us=rel_end_ns / 1000, 572*da0073e9SAndroid Build Coastguard Worker fwd_thread=kineto_event.fwd_thread_id(), 573*da0073e9SAndroid Build Coastguard Worker input_shapes=kineto_event.shapes(), 574*da0073e9SAndroid Build Coastguard Worker concrete_inputs=kineto_event.concrete_inputs(), 575*da0073e9SAndroid Build Coastguard Worker kwinputs=kineto_event.kwinputs(), 
576*da0073e9SAndroid Build Coastguard Worker stack=[ 577*da0073e9SAndroid Build Coastguard Worker entry 578*da0073e9SAndroid Build Coastguard Worker for entry in kineto_event.stack() 579*da0073e9SAndroid Build Coastguard Worker if _filter_stack_entry(entry) 580*da0073e9SAndroid Build Coastguard Worker ], 581*da0073e9SAndroid Build Coastguard Worker scope=kineto_event.scope(), 582*da0073e9SAndroid Build Coastguard Worker use_device=self.use_device, 583*da0073e9SAndroid Build Coastguard Worker cpu_memory_usage=cpu_memory_usage, 584*da0073e9SAndroid Build Coastguard Worker device_memory_usage=device_memory_usage, 585*da0073e9SAndroid Build Coastguard Worker is_async=is_async, 586*da0073e9SAndroid Build Coastguard Worker sequence_nr=kineto_event.sequence_nr(), 587*da0073e9SAndroid Build Coastguard Worker device_type=kineto_event.device_type(), 588*da0073e9SAndroid Build Coastguard Worker device_index=kineto_event.device_index(), 589*da0073e9SAndroid Build Coastguard Worker device_resource_id=kineto_event.device_resource_id(), 590*da0073e9SAndroid Build Coastguard Worker flops=kineto_event.flops(), 591*da0073e9SAndroid Build Coastguard Worker is_user_annotation=kineto_event.is_user_annotation(), 592*da0073e9SAndroid Build Coastguard Worker ) 593*da0073e9SAndroid Build Coastguard Worker max_evt_id = max(max_evt_id, fe.id) 594*da0073e9SAndroid Build Coastguard Worker if fe.device_type == DeviceType.CPU and not fe.is_async: 595*da0073e9SAndroid Build Coastguard Worker if self.use_device == "privateuseone": 596*da0073e9SAndroid Build Coastguard Worker privateuse1_time = kineto_event.privateuse1_elapsed_us() 597*da0073e9SAndroid Build Coastguard Worker if privateuse1_time > 0: 598*da0073e9SAndroid Build Coastguard Worker fe.append_kernel(fe.name, fe.device_index, privateuse1_time) 599*da0073e9SAndroid Build Coastguard Worker fe.is_legacy = True 600*da0073e9SAndroid Build Coastguard Worker elif self.use_device == "cuda": 601*da0073e9SAndroid Build Coastguard Worker # Check if 
we have CUDA time as a fallback 602*da0073e9SAndroid Build Coastguard Worker cuda_time = kineto_event.cuda_elapsed_us() 603*da0073e9SAndroid Build Coastguard Worker if cuda_time > 0: 604*da0073e9SAndroid Build Coastguard Worker fe.append_kernel(fe.name, fe.device_index, cuda_time) 605*da0073e9SAndroid Build Coastguard Worker fe.is_legacy = True 606*da0073e9SAndroid Build Coastguard Worker all_function_events.append(fe) 607*da0073e9SAndroid Build Coastguard Worker corr_id = kineto_event.linked_correlation_id() 608*da0073e9SAndroid Build Coastguard Worker if corr_id > 0: 609*da0073e9SAndroid Build Coastguard Worker if corr_id not in device_corr_map: 610*da0073e9SAndroid Build Coastguard Worker device_corr_map[corr_id] = [] 611*da0073e9SAndroid Build Coastguard Worker device_corr_map[corr_id].append(fe) 612*da0073e9SAndroid Build Coastguard Worker elif corr_id == 0: 613*da0073e9SAndroid Build Coastguard Worker frontend_function_events.append(fe) 614*da0073e9SAndroid Build Coastguard Worker else: 615*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 616*da0073e9SAndroid Build Coastguard Worker f"Got negative correlation id {corr_id} in profiler post processing" 617*da0073e9SAndroid Build Coastguard Worker ) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker # associate device kernels and device runtime (CPU) with CPU events 620*da0073e9SAndroid Build Coastguard Worker for fe in frontend_function_events: 621*da0073e9SAndroid Build Coastguard Worker if ( 622*da0073e9SAndroid Build Coastguard Worker fe.device_type == DeviceType.CPU 623*da0073e9SAndroid Build Coastguard Worker and not fe.is_async 624*da0073e9SAndroid Build Coastguard Worker and fe.id in device_corr_map 625*da0073e9SAndroid Build Coastguard Worker ): 626*da0073e9SAndroid Build Coastguard Worker for f_evt in device_corr_map[fe.id]: 627*da0073e9SAndroid Build Coastguard Worker if ( 628*da0073e9SAndroid Build Coastguard Worker f_evt.device_type == 
DeviceType.CUDA 629*da0073e9SAndroid Build Coastguard Worker or f_evt.device_type == DeviceType.PrivateUse1 630*da0073e9SAndroid Build Coastguard Worker ): 631*da0073e9SAndroid Build Coastguard Worker fe.append_kernel( 632*da0073e9SAndroid Build Coastguard Worker f_evt.name, 633*da0073e9SAndroid Build Coastguard Worker f_evt.device_index, 634*da0073e9SAndroid Build Coastguard Worker f_evt.time_range.end - f_evt.time_range.start, 635*da0073e9SAndroid Build Coastguard Worker ) 636*da0073e9SAndroid Build Coastguard Worker elif f_evt.device_type == DeviceType.CPU: 637*da0073e9SAndroid Build Coastguard Worker # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated 638*da0073e9SAndroid Build Coastguard Worker # with the 'thread' of the corresponding linked PyTorch event to properly track 639*da0073e9SAndroid Build Coastguard Worker # parents and children 640*da0073e9SAndroid Build Coastguard Worker f_evt.thread = fe.thread 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker def createFunctionEventForMemoryEvents(evt): 643*da0073e9SAndroid Build Coastguard Worker rel_start_ns = evt.start_ns() - trace_start_ns 644*da0073e9SAndroid Build Coastguard Worker fe = FunctionEvent( 645*da0073e9SAndroid Build Coastguard Worker id=max_evt_id, 646*da0073e9SAndroid Build Coastguard Worker name=evt.name(), 647*da0073e9SAndroid Build Coastguard Worker trace_name=None, # not outputting in the trace 648*da0073e9SAndroid Build Coastguard Worker thread=evt.start_thread_id(), 649*da0073e9SAndroid Build Coastguard Worker start_us=rel_start_ns / 1000, 650*da0073e9SAndroid Build Coastguard Worker end_us=rel_start_ns / 1000, # no duration 651*da0073e9SAndroid Build Coastguard Worker fwd_thread=evt.start_thread_id(), 652*da0073e9SAndroid Build Coastguard Worker input_shapes=[], 653*da0073e9SAndroid Build Coastguard Worker stack=[], 654*da0073e9SAndroid Build Coastguard Worker scope=0, # RecordScope::FUNCTION 655*da0073e9SAndroid 
Build Coastguard Worker use_device=self.use_device, 656*da0073e9SAndroid Build Coastguard Worker cpu_memory_usage=_cpu_memory_usage(evt), 657*da0073e9SAndroid Build Coastguard Worker device_memory_usage=_device_memory_usage(evt), 658*da0073e9SAndroid Build Coastguard Worker is_async=False, 659*da0073e9SAndroid Build Coastguard Worker sequence_nr=-1, 660*da0073e9SAndroid Build Coastguard Worker device_type=DeviceType.CPU, 661*da0073e9SAndroid Build Coastguard Worker device_index=0, 662*da0073e9SAndroid Build Coastguard Worker ) 663*da0073e9SAndroid Build Coastguard Worker return fe 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker # output top-level memory events 666*da0073e9SAndroid Build Coastguard Worker for mem_record in mem_records: 667*da0073e9SAndroid Build Coastguard Worker if not mem_record[1]: 668*da0073e9SAndroid Build Coastguard Worker max_evt_id += 1 669*da0073e9SAndroid Build Coastguard Worker fe = createFunctionEventForMemoryEvents(mem_record[0]) 670*da0073e9SAndroid Build Coastguard Worker all_function_events.append(fe) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker for oom_record in oom_records: 673*da0073e9SAndroid Build Coastguard Worker max_evt_id += 1 674*da0073e9SAndroid Build Coastguard Worker fe = createFunctionEventForMemoryEvents(oom_record) 675*da0073e9SAndroid Build Coastguard Worker all_function_events.append(fe) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker all_function_events.sort( 678*da0073e9SAndroid Build Coastguard Worker key=lambda evt: [evt.time_range.start, -evt.time_range.end] 679*da0073e9SAndroid Build Coastguard Worker ) 680*da0073e9SAndroid Build Coastguard Worker return all_function_events 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Workerclass record_function(_ContextDecorator): 684*da0073e9SAndroid Build Coastguard 
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
    Label will only appear if CPU activity tracing is enabled.

    It is useful when tracing the code profile.

    Args:
        name (str): Label assigned to the block of code.
        node_id (int): ID of node, for distributed profiling. Unset in
        non-distributed cases.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
        ...         z = y ** 3
        ...     y.backward()
        ...
        >>> # xdoctest: +IGNORE_WANT
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total %  CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        pow                                  60.77%           47.470us         3
        mul                                  21.73%           25.465us         2
        PowBackward0                         12.03%           121.891us        1
        torch::autograd::AccumulateGrad      2.70%            6.324us          1
        label-z                              2.13%            12.421us         1
        torch::autograd::GraphRoot           0.64%            1.503us          1
        -----------------------------------  ---------------  ---------------  ---------------
        Self CPU time total: 234.344us
        CUDA time total: 0.000us

    """

    def __init__(self, name: str, args: Optional[str] = None):
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )

    def __enter__(self):
        # Open the profiler record; the handle is kept so __exit__ (or the
        # future callback path) can close it later.
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        # When _call_end_callbacks_on_future() was used, the record is closed
        # by the future's callback instead of here.
        if not self.run_callbacks_on_exit:
            return

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            torch.ops.profiler._record_function_exit(record)

    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end to end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut: (torch._C.Future): future for which to schedule
            callback for.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have ran.

        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future

class emit_itt:
    """Context manager that makes every autograd operation emit an ITT range.

    It is useful when running the program under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager,
    you will be able to see labeled ranges in Intel(R) VTune Profiler GUI.

    .. warning:
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side. Also note that shape recording may increase the overhead of itt range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)

    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        # Guards against re-entrant use of the (non-reentrant) profiler.
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ITT annotation context manager is not reentrant")
        self.entered = True
        _run_on_profiler_start()
        itt_config = ProfilerConfig(
            ProfilerState.ITT,
            self.record_shapes,
            False,
            False,
            False,
            False,
            _ExperimentalConfig(),
        )
        _enable_profiler(itt_config, set())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
865*da0073e9SAndroid Build Coastguard Worker _disable_profiler() 866*da0073e9SAndroid Build Coastguard Worker _run_on_profiler_stop() 867*da0073e9SAndroid Build Coastguard Worker return False 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Workerclass emit_nvtx: 871*da0073e9SAndroid Build Coastguard Worker """Context manager that makes every autograd operation emit an NVTX range. 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker It is useful when running the program under nvprof:: 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker nvprof --profile-from-start off -o trace_name.prof -- <regular command here> 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker Unfortunately, there's no way to force nvprof to flush the data it collected 878*da0073e9SAndroid Build Coastguard Worker to disk, so for CUDA profiling one has to use this context manager to annotate 879*da0073e9SAndroid Build Coastguard Worker nvprof traces and wait for the process to exit before inspecting them. 880*da0073e9SAndroid Build Coastguard Worker Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or 881*da0073e9SAndroid Build Coastguard Worker :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection 882*da0073e9SAndroid Build Coastguard Worker e.g. in Python REPL. 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker .. warning: 885*da0073e9SAndroid Build Coastguard Worker This context manager should not be called recursively, i.e. at most one 886*da0073e9SAndroid Build Coastguard Worker instance should be enabled at any given time. 
887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker Args: 889*da0073e9SAndroid Build Coastguard Worker enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op. 890*da0073e9SAndroid Build Coastguard Worker Default: ``True``. 891*da0073e9SAndroid Build Coastguard Worker record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping 892*da0073e9SAndroid Build Coastguard Worker each autograd op will append information about the sizes of Tensor arguments received 893*da0073e9SAndroid Build Coastguard Worker by that op, in the following format: 894*da0073e9SAndroid Build Coastguard Worker ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]`` 895*da0073e9SAndroid Build Coastguard Worker Non-tensor arguments will be represented by ``[]``. 896*da0073e9SAndroid Build Coastguard Worker Arguments will be listed in the order they are received by the backend op. 897*da0073e9SAndroid Build Coastguard Worker Please note that this order may not match the order in which those arguments were passed 898*da0073e9SAndroid Build Coastguard Worker on the Python side. Also note that shape recording may increase the overhead of nvtx range creation. 899*da0073e9SAndroid Build Coastguard Worker Default: ``False`` 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker Example: 902*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP("undefined variables") 903*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) 904*da0073e9SAndroid Build Coastguard Worker >>> with torch.cuda.profiler.profile(): 905*da0073e9SAndroid Build Coastguard Worker ... model(x) # Warmup CUDA memory allocator and profiler 906*da0073e9SAndroid Build Coastguard Worker ... with torch.autograd.profiler.emit_nvtx(): 907*da0073e9SAndroid Build Coastguard Worker ... 
model(x) 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker **Forward-backward correlation** 910*da0073e9SAndroid Build Coastguard Worker 911*da0073e9SAndroid Build Coastguard Worker When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler, 912*da0073e9SAndroid Build Coastguard Worker correlating each backward-pass op with the corresponding forward-pass op can be difficult. 913*da0073e9SAndroid Build Coastguard Worker To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it 914*da0073e9SAndroid Build Coastguard Worker generates. 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running 917*da0073e9SAndroid Build Coastguard Worker counter, incremented each time a new backward Function object is created and stashed for backward. 918*da0073e9SAndroid Build Coastguard Worker Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that 919*da0073e9SAndroid Build Coastguard Worker if a backward Function object is created by this forward function, 920*da0073e9SAndroid Build Coastguard Worker the backward object will receive sequence number N. 921*da0073e9SAndroid Build Coastguard Worker During the backward pass, the top-level range wrapping each C++ backward Function's 922*da0073e9SAndroid Build Coastguard Worker ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that 923*da0073e9SAndroid Build Coastguard Worker the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq`` 924*da0073e9SAndroid Build Coastguard Worker numbers in forward, you can track down which forward op created each backward Function. 
925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker Any functions executed during the backward pass are also decorated with ``seq=<N>``. During 927*da0073e9SAndroid Build Coastguard Worker default backward (with ``create_graph=False``) this information is irrelevant, and in fact, 928*da0073e9SAndroid Build Coastguard Worker ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with 929*da0073e9SAndroid Build Coastguard Worker backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function 930*da0073e9SAndroid Build Coastguard Worker objects with the earlier forward pass. 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker **Double-backward** 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words, 935*da0073e9SAndroid Build Coastguard Worker if you are setting up for a double-backward), each function's execution during backward 936*da0073e9SAndroid Build Coastguard Worker is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects 937*da0073e9SAndroid Build Coastguard Worker to be executed later during double-backward, just as the original functions in the forward pass did. 
938*da0073e9SAndroid Build Coastguard Worker The relationship between backward and double-backward is conceptually the same as the relationship 939*da0073e9SAndroid Build Coastguard Worker between forward and backward: The functions still emit current-sequence-number-tagged ranges, 940*da0073e9SAndroid Build Coastguard Worker the Function objects they create still stash those sequence numbers, and during the eventual 941*da0073e9SAndroid Build Coastguard Worker double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq`` 942*da0073e9SAndroid Build Coastguard Worker numbers, which can be compared to `seq` numbers from the backward pass. 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker .. warning: 945*da0073e9SAndroid Build Coastguard Worker The sequence number is thread-local, and some forward functions don't create an associated 946*da0073e9SAndroid Build Coastguard Worker backward Function object (instead delegating that to sub-functions further down the call chain). 947*da0073e9SAndroid Build Coastguard Worker For these reasons, the correspondence of stashed sequence numbers in 948*da0073e9SAndroid Build Coastguard Worker backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is 949*da0073e9SAndroid Build Coastguard Worker not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully 950*da0073e9SAndroid Build Coastguard Worker disambiguate which forward function created which 951*da0073e9SAndroid Build Coastguard Worker backward Function object. You may need to make a judgment based on analytic knowledge of what 952*da0073e9SAndroid Build Coastguard Worker the expected correspondence should be. 
953*da0073e9SAndroid Build Coastguard Worker """ 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker def __init__(self, enabled=True, record_shapes=False): 956*da0073e9SAndroid Build Coastguard Worker self.enabled = enabled 957*da0073e9SAndroid Build Coastguard Worker self.entered = False 958*da0073e9SAndroid Build Coastguard Worker self.record_shapes = record_shapes 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 961*da0073e9SAndroid Build Coastguard Worker if not self.enabled: 962*da0073e9SAndroid Build Coastguard Worker return 963*da0073e9SAndroid Build Coastguard Worker if self.entered: 964*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("NVTX annotation context manager is not reentrant") 965*da0073e9SAndroid Build Coastguard Worker self.entered = True 966*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 967*da0073e9SAndroid Build Coastguard Worker _run_on_profiler_start() 968*da0073e9SAndroid Build Coastguard Worker _enable_profiler( 969*da0073e9SAndroid Build Coastguard Worker ProfilerConfig( 970*da0073e9SAndroid Build Coastguard Worker ProfilerState.NVTX, 971*da0073e9SAndroid Build Coastguard Worker self.record_shapes, 972*da0073e9SAndroid Build Coastguard Worker False, 973*da0073e9SAndroid Build Coastguard Worker False, 974*da0073e9SAndroid Build Coastguard Worker False, 975*da0073e9SAndroid Build Coastguard Worker False, 976*da0073e9SAndroid Build Coastguard Worker _ExperimentalConfig(), 977*da0073e9SAndroid Build Coastguard Worker ), 978*da0073e9SAndroid Build Coastguard Worker set(), 979*da0073e9SAndroid Build Coastguard Worker ) 980*da0073e9SAndroid Build Coastguard Worker return self 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker def __exit__(self, exc_type, exc_val, exc_tb): 983*da0073e9SAndroid Build Coastguard Worker if not self.enabled: 984*da0073e9SAndroid Build Coastguard 
Worker return 985*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 986*da0073e9SAndroid Build Coastguard Worker _disable_profiler() 987*da0073e9SAndroid Build Coastguard Worker _run_on_profiler_stop() 988*da0073e9SAndroid Build Coastguard Worker return False 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Workerdef load_nvprof(path): 992*da0073e9SAndroid Build Coastguard Worker """Open an nvprof trace file and parses autograd annotations. 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker Args: 995*da0073e9SAndroid Build Coastguard Worker path (str): path to nvprof trace 996*da0073e9SAndroid Build Coastguard Worker """ 997*da0073e9SAndroid Build Coastguard Worker return EventList(parse_nvprof_trace(path)) 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Workerclass EnforceUnique: 1001*da0073e9SAndroid Build Coastguard Worker """Raises an error if a key is seen more than once.""" 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker def __init__(self): 1004*da0073e9SAndroid Build Coastguard Worker self.seen = set() 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker def see(self, *key): 1007*da0073e9SAndroid Build Coastguard Worker r""" 1008*da0073e9SAndroid Build Coastguard Worker Observe a key and raise an error if it is seen multiple times. 
1009*da0073e9SAndroid Build Coastguard Worker """ 1010*da0073e9SAndroid Build Coastguard Worker if key in self.seen: 1011*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("duplicate key: " + str(key)) 1012*da0073e9SAndroid Build Coastguard Worker self.seen.add(key) 1013*da0073e9SAndroid Build Coastguard Worker 1014*da0073e9SAndroid Build Coastguard Worker 1015*da0073e9SAndroid Build Coastguard Workerdef parse_nvprof_trace(path): 1016*da0073e9SAndroid Build Coastguard Worker import sqlite3 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker conn = sqlite3.connect(path) 1019*da0073e9SAndroid Build Coastguard Worker conn.row_factory = sqlite3.Row 1020*da0073e9SAndroid Build Coastguard Worker 1021*da0073e9SAndroid Build Coastguard Worker # Parse strings table 1022*da0073e9SAndroid Build Coastguard Worker strings = {} 1023*da0073e9SAndroid Build Coastguard Worker for r in conn.execute("SELECT _id_ as id, value FROM StringTable"): 1024*da0073e9SAndroid Build Coastguard Worker strings[r["id"]] = torch._C._demangle(r["value"]) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker # First, find all functions and create FunctionEvents for them 1027*da0073e9SAndroid Build Coastguard Worker marker_query = """ 1028*da0073e9SAndroid Build Coastguard Worker SELECT 1029*da0073e9SAndroid Build Coastguard Worker start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time 1030*da0073e9SAndroid Build Coastguard Worker FROM 1031*da0073e9SAndroid Build Coastguard Worker CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end 1032*da0073e9SAndroid Build Coastguard Worker ON start.id = end.id 1033*da0073e9SAndroid Build Coastguard Worker WHERE 1034*da0073e9SAndroid Build Coastguard Worker start.name != 0 AND end.name = 0 1035*da0073e9SAndroid Build Coastguard Worker """ 1036*da0073e9SAndroid Build Coastguard Worker functions = [] 
1037*da0073e9SAndroid Build Coastguard Worker functions_map = {} 1038*da0073e9SAndroid Build Coastguard Worker unique = EnforceUnique() 1039*da0073e9SAndroid Build Coastguard Worker for row in conn.execute(marker_query): 1040*da0073e9SAndroid Build Coastguard Worker unique.see(row["marker_id"]) 1041*da0073e9SAndroid Build Coastguard Worker evt = FunctionEvent( 1042*da0073e9SAndroid Build Coastguard Worker id=row["marker_id"], 1043*da0073e9SAndroid Build Coastguard Worker node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure 1044*da0073e9SAndroid Build Coastguard Worker # that pytorch doesn't crash when creating a FunctionEvent() object 1045*da0073e9SAndroid Build Coastguard Worker name=strings[row["name"]], 1046*da0073e9SAndroid Build Coastguard Worker start_us=row["start_time"], 1047*da0073e9SAndroid Build Coastguard Worker end_us=row["end_time"], 1048*da0073e9SAndroid Build Coastguard Worker thread=0, 1049*da0073e9SAndroid Build Coastguard Worker ) # TODO: find in sqlite database 1050*da0073e9SAndroid Build Coastguard Worker functions.append(evt) 1051*da0073e9SAndroid Build Coastguard Worker functions_map[evt.id] = evt 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker # Now, correlate all kernels with FunctionEvents 1054*da0073e9SAndroid Build Coastguard Worker kernel_query = """ 1055*da0073e9SAndroid Build Coastguard Worker SELECT 1056*da0073e9SAndroid Build Coastguard Worker start.id AS marker_id, start.name, start.timestamp, end.timestamp, 1057*da0073e9SAndroid Build Coastguard Worker runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end, 1058*da0073e9SAndroid Build Coastguard Worker kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name 1059*da0073e9SAndroid Build Coastguard Worker FROM 1060*da0073e9SAndroid Build Coastguard Worker CUPTI_ACTIVITY_KIND_MARKER AS start 1061*da0073e9SAndroid Build Coastguard Worker 
INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end 1062*da0073e9SAndroid Build Coastguard Worker ON start.id = end.id 1063*da0073e9SAndroid Build Coastguard Worker INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime 1064*da0073e9SAndroid Build Coastguard Worker ON (start.timestamp < runtime.start AND runtime.end < end.timestamp) 1065*da0073e9SAndroid Build Coastguard Worker INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel 1066*da0073e9SAndroid Build Coastguard Worker ON kernel.correlationId = runtime.correlationId 1067*da0073e9SAndroid Build Coastguard Worker """ 1068*da0073e9SAndroid Build Coastguard Worker unique = EnforceUnique() 1069*da0073e9SAndroid Build Coastguard Worker for row in conn.execute(kernel_query): 1070*da0073e9SAndroid Build Coastguard Worker unique.see(row["marker_id"], row["runtime_id"]) 1071*da0073e9SAndroid Build Coastguard Worker # 211 is cudaKernelLaunch for cuda >= 9.2 1072*da0073e9SAndroid Build Coastguard Worker assert row["cbid"] == 211 1073*da0073e9SAndroid Build Coastguard Worker evt = functions_map[row["marker_id"]] 1074*da0073e9SAndroid Build Coastguard Worker evt.append_kernel( 1075*da0073e9SAndroid Build Coastguard Worker row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"] 1076*da0073e9SAndroid Build Coastguard Worker ) 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker functions.sort(key=lambda evt: evt.time_range.start) 1079*da0073e9SAndroid Build Coastguard Worker return functions 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Workerclass KinetoStepTracker: 1083*da0073e9SAndroid Build Coastguard Worker """Provides an abstraction for incrementing the step count globally. 
1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker Previously, we only had one place to mark that a step() has occurred 1086*da0073e9SAndroid Build Coastguard Worker in the program via pytorch profiler step(). We will now add step hooks 1087*da0073e9SAndroid Build Coastguard Worker in the Optimizer class https://github.com/pytorch/pytorch/issues/88446 1088*da0073e9SAndroid Build Coastguard Worker 1089*da0073e9SAndroid Build Coastguard Worker - This could mean programs that already call profiler.step() every 1090*da0073e9SAndroid Build Coastguard Worker iteration can end up double incrementing step count. 1091*da0073e9SAndroid Build Coastguard Worker - If a model uses multiple optimizers we can also have double or more 1092*da0073e9SAndroid Build Coastguard Worker counting of the step. 1093*da0073e9SAndroid Build Coastguard Worker 1094*da0073e9SAndroid Build Coastguard Worker We fix this by adding a layer of abstraction before calling step() 1095*da0073e9SAndroid Build Coastguard Worker to the kineto library. The idea is to maintain steps per requester in a dict: 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker .. code-block:: 1098*da0073e9SAndroid Build Coastguard Worker 1099*da0073e9SAndroid Build Coastguard Worker { 1100*da0073e9SAndroid Build Coastguard Worker "ProfilerStep": 100, # triggered by profiler step() call 1101*da0073e9SAndroid Build Coastguard Worker "Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc 1102*da0073e9SAndroid Build Coastguard Worker "Optimizer2Step": 100, 1103*da0073e9SAndroid Build Coastguard Worker } 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Worker To figure out the global step count just take the max of dict values (100). 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker If one of the count increments the max will go up. 
1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Worker .. code-block:: 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker { 1112*da0073e9SAndroid Build Coastguard Worker "ProfilerStep": 100, 1113*da0073e9SAndroid Build Coastguard Worker "Optimizer1Step": 101, # Optimizer1 got incremented first say 1114*da0073e9SAndroid Build Coastguard Worker "Optimizer2Step": 100, 1115*da0073e9SAndroid Build Coastguard Worker } 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker Then global step count is 101 1118*da0073e9SAndroid Build Coastguard Worker We only call the kineto step() function when global count increments. 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker NOTE: Please do not use the KinetoStepTracker in modules beside the Optimizer 1121*da0073e9SAndroid Build Coastguard Worker for now. The result could be incorrect increments of the step count. 1122*da0073e9SAndroid Build Coastguard Worker """ 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker _current_step = 0 1125*da0073e9SAndroid Build Coastguard Worker _step_dict: Dict[str, int] = defaultdict(int) 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker @classmethod 1128*da0073e9SAndroid Build Coastguard Worker def init_step_count(cls, requester: str): 1129*da0073e9SAndroid Build Coastguard Worker r""" 1130*da0073e9SAndroid Build Coastguard Worker Initialize for a given requester. 
1131*da0073e9SAndroid Build Coastguard Worker """ 1132*da0073e9SAndroid Build Coastguard Worker cls._step_dict[requester] = cls._current_step 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker @classmethod 1135*da0073e9SAndroid Build Coastguard Worker def erase_step_count(cls, requester: str) -> bool: 1136*da0073e9SAndroid Build Coastguard Worker r""" 1137*da0073e9SAndroid Build Coastguard Worker Remove a given requester. 1138*da0073e9SAndroid Build Coastguard Worker """ 1139*da0073e9SAndroid Build Coastguard Worker return cls._step_dict.pop(requester, None) is not None 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Worker @classmethod 1142*da0073e9SAndroid Build Coastguard Worker def increment_step(cls, requester: str) -> int: 1143*da0073e9SAndroid Build Coastguard Worker """Increments the step count for the requester. 1144*da0073e9SAndroid Build Coastguard Worker 1145*da0073e9SAndroid Build Coastguard Worker Additionally if the max over all step counts has incremented then 1146*da0073e9SAndroid Build Coastguard Worker trigger the _kineto_step() returns global step count 1147*da0073e9SAndroid Build Coastguard Worker """ 1148*da0073e9SAndroid Build Coastguard Worker if requester not in cls._step_dict: 1149*da0073e9SAndroid Build Coastguard Worker cls.init_step_count(requester) 1150*da0073e9SAndroid Build Coastguard Worker cls._step_dict[requester] += 1 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker new_step = max(cls._step_dict.values()) 1153*da0073e9SAndroid Build Coastguard Worker if new_step > cls._current_step: 1154*da0073e9SAndroid Build Coastguard Worker delta = new_step - cls._current_step 1155*da0073e9SAndroid Build Coastguard Worker if delta > 1: 1156*da0073e9SAndroid Build Coastguard Worker warn( 1157*da0073e9SAndroid Build Coastguard Worker "Profiler step count has increased more than 1 - " 1158*da0073e9SAndroid Build Coastguard 
Worker f"current_step = {cls._current_step} step dict = {cls._step_dict}" 1159*da0073e9SAndroid Build Coastguard Worker ) 1160*da0073e9SAndroid Build Coastguard Worker for _ in range(0, delta): 1161*da0073e9SAndroid Build Coastguard Worker _kineto_step() 1162*da0073e9SAndroid Build Coastguard Worker cls._current_step = new_step 1163*da0073e9SAndroid Build Coastguard Worker return cls._current_step 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker @classmethod 1166*da0073e9SAndroid Build Coastguard Worker def current_step(cls) -> int: 1167*da0073e9SAndroid Build Coastguard Worker r""" 1168*da0073e9SAndroid Build Coastguard Worker Get the latest step for any requester 1169*da0073e9SAndroid Build Coastguard Worker """ 1170*da0073e9SAndroid Build Coastguard Worker return cls._current_step 1171