# mypy: allow-untyped-defs
import gzip
import json
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing_extensions import Self
from warnings import warn

import torch
import torch.autograd.profiler as prof
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _ExperimentalConfig,
    _remove_execution_trace_observer,
)
from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline


__all__ = [
    "supported_activities",
    "ProfilerAction",
    "schedule",
    "tensorboard_trace_handler",
    "profile",
    "ExecutionTraceObserver",
]
PROFILER_STEP_NAME = "ProfilerStep"


def supported_activities():
    """
    Returns a set of supported profiler tracing activities.

    Note: the profiler uses the CUPTI library to trace on-device CUDA kernels.
    When CUDA is enabled but CUPTI is not available, passing
    ``ProfilerActivity.CUDA`` to the profiler results in using the legacy CUDA
    profiling code (same as in the legacy ``torch.autograd.profiler``).
    This, in turn, results in CUDA time being included in the profiler table output,
    but not in the JSON trace.
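
    For illustration, the returned set can be used to pick only the activities that
    are actually available (a minimal sketch; ``activities`` is just an example name):

    .. code-block:: python

        activities = [torch.profiler.ProfilerActivity.CPU]
        if torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities():
            activities.append(torch.profiler.ProfilerActivity.CUDA)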
    """
    return torch.autograd._supported_activities()


class _ITraceObserver(ABC):
    """Abstract interface for a Trace observer.
    It defines 3 methods: start, stop and cleanup"""

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass


class _KinetoProfile:
    """Low-level profiler that wraps the autograd profiler.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
            for more details).
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use a formula to estimate the FLOPs of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record the module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles

    .. note::
        This API is experimental and subject to change in the future.

        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.
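
    A minimal usage sketch of this low-level API (``code_to_profile`` is a placeholder;
    most users will prefer the higher-level ``torch.profiler.profile`` context manager):

    .. code-block:: python

        p = _KinetoProfile(activities=[torch.profiler.ProfilerActivity.CPU])
        p.start()
        code_to_profile()
        p.stop()
        p.export_chrome_trace("trace.json")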
    """

    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        acc_events: bool = False,
    ):
        self.activities = set(activities) if activities else supported_activities()
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.experimental_config = experimental_config
        self.execution_trace_observer = execution_trace_observer
        self.acc_events = acc_events
        self.profiler: Optional[prof.profile] = None
        self.mem_tl: Optional[MemoryProfileTimeline] = None
        self.use_device = None
        if ProfilerActivity.CUDA in self.activities:
            self.use_device = "cuda"
        elif ProfilerActivity.XPU in self.activities:
            self.use_device = "xpu"
        elif ProfilerActivity.MTIA in self.activities:
            self.use_device = "mtia"
        elif ProfilerActivity.PrivateUse1 in self.activities:
            self.use_device = _get_privateuse1_backend_name()

        # user-defined metadata to be amended to the trace
        self.preset_metadata: Dict[str, str] = {}

    def start(self):
        self.prepare_trace()
        self.start_trace()

    def stop(self):
        self.stop_trace()

    def prepare_trace(self):
        if (self.profiler is None) or (not self.acc_events):
            self.profiler = prof.profile(
                use_cpu=(ProfilerActivity.CPU in self.activities),
                use_device=self.use_device,
                record_shapes=self.record_shapes,
                with_flops=self.with_flops,
                profile_memory=self.profile_memory,
                with_stack=self.with_stack,
                with_modules=self.with_modules,
                use_kineto=True,
                experimental_config=self.experimental_config,
                acc_events=self.acc_events,
            )
        self.profiler._prepare_trace()

    def start_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.start()
        assert self.profiler is not None
        self.profiler._start_trace()

        if self.profile_memory:
            self.add_metadata_json("profile_memory", "1")
        if self.with_stack:
            self.add_metadata_json("with_stack", "1")
        if self.record_shapes:
            self.add_metadata_json("record_shapes", "1")
        if self.with_modules:
            self.add_metadata_json("with_modules", "1")
        if self.with_flops:
            self.add_metadata_json("with_flops", "1")

        if kineto_available():
            dist_info = self._get_distributed_info()
            if dist_info:
                self.add_metadata_json("distributedInfo", json.dumps(dist_info))

            if hasattr(torch, "_inductor"):
                import torch._inductor.config as inductor_config

                if inductor_config.triton.cudagraphs:
                    os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
                    self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
                    # FIXME: CUDA Graph does not work well with CUPTI teardown.
                    # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
                    # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
                    # Workaround: turn off CUPTI teardown when using CUDA Graphs.
                    os.environ["TEARDOWN_CUPTI"] = "0"

        # Insert the preset user metadata to the trace
        for k, v in self.preset_metadata.items():
            self.add_metadata_json(k, v)

    def stop_trace(self):
        if self.execution_trace_observer:
            self.execution_trace_observer.stop()
        assert self.profiler is not None
        self.profiler.__exit__(None, None, None)

    def export_chrome_trace(self, path: str):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only the
        last cycle in the schedule is exported.
        """
        assert self.profiler
        if path.endswith(".gz"):
            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
            fp.close()
            retvalue = self.profiler.export_chrome_trace(fp.name)
            with open(fp.name) as fin:
                with gzip.open(path, "wt") as fout:
                    fout.writelines(fin)
            os.remove(fp.name)
            return retvalue
        else:
            return self.profiler.export_chrome_trace(path)

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Save stack traces to a file

        Args:
            path (str): save stacks file to this location;
            metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
        """
        assert self.profiler
        return self.profiler.export_stacks(path, metric)

    def toggle_collection_dynamic(
        self, enable: bool, activities: Iterable[ProfilerActivity]
    ):
        """Toggles collection of activities on/off at any point of collection. Currently supports
        toggling Torch Ops (CPU) and the CUDA activities supported in Kineto

        Args:
            activities (iterable): list of activity groups to use in profiling, supported values:
                ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``

        Examples:

        .. code-block:: python

            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]
            ) as p:
                code_to_profile_0()
                # turn off collection of all CUDA activity
                p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_1()
                # turn on collection of all CUDA activity
                p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
                code_to_profile_2()
            print(p.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
        """
        if not self.profiler:
            return
        self.profiler.toggle_collection_dynamic(enable, activities)

    def key_averages(
        self, group_by_input_shape: bool = False, group_by_stack_n: int = 0
    ):
        """Averages events, grouping them by operator name and (optionally) input shapes and
        stack.

        .. note::
            To use shape/stack functionality make sure to set record_shapes/with_stack
            when creating the profiler context manager.
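
        A minimal usage sketch (``code_to_profile`` is a placeholder for user code):

        .. code-block:: python

            with torch.profiler.profile(record_shapes=True, with_stack=True) as p:
                code_to_profile()
            # group events that share an operator name and input shapes
            print(p.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total"))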
        """
        assert self.profiler
        return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)

    def events(self):
        """
        Returns the list of unaggregated profiler events,
        to be used in the trace callback or after the profiling is finished
        """
        assert self.profiler
        return self.profiler.function_events

    def add_metadata(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a string value
        into the trace file
        """
        wrapped_value = '"' + value.replace('"', '\\"') + '"'
        torch.autograd._add_metadata_json(key, wrapped_value)

    def add_metadata_json(self, key: str, value: str):
        """
        Adds user-defined metadata with a string key and a valid JSON value
        into the trace file
        """
        torch.autograd._add_metadata_json(key, value)

    def preset_metadata_json(self, key: str, value: str):
        """
        Presets user-defined metadata while the profiler is not started,
        to be added into the trace file later.
        Metadata is in the format of a string key and a valid JSON value
        """
        self.preset_metadata[key] = value

    def _get_distributed_info(self):
        import torch.distributed as dist

        if not dist.is_available() or not dist.is_initialized():
            return None

        backend = dist.get_backend()
        dist_info = {
            "backend": backend,
            "rank": dist.get_rank(),
            "world_size": dist.get_world_size(),
            "pg_count": dist.get_pg_count(),
            "pg_config": dist.distributed_c10d._get_all_pg_configs(),
        }
        if backend == "nccl":
            nccl_version = torch.cuda.nccl.version()
            dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
        return dist_info

    def _memory_profile(self) -> MemoryProfile:
        required = ("record_shapes", "profile_memory", "with_stack")
        missing = [f"{i}=True" for i in required if not getattr(self, i)]
        if missing:
            raise ValueError(f"{', '.join(missing)} required for memory profiling.")

        assert self.profiler is not None and self.profiler.kineto_results is not None
        return MemoryProfile(self.profiler.kineto_results)

    def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
        """Export memory event information from the profiler-collected
        tree for a given device, and export a timeline plot. There are 3
        exportable files using ``export_memory_timeline``, each controlled by the
        ``path``'s suffix.

        - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline
          plot will be embedded as a PNG file in the HTML file.

        - For plot points consisting of ``[times, [sizes by category]]``, where
          ``times`` are timestamps and ``sizes`` are memory usage for each category,
          the memory timeline plot will be saved as a JSON (``.json``) or gzipped JSON
          (``.json.gz``) file, depending on the suffix.

        - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
          event will consist of ``(timestamp, action, numbytes, category)``, where
          ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
          and ``category`` is one of the enums from
          ``torch.profiler._memory_profiler.Category``.

        Output: Memory timeline written as gzipped JSON, JSON, or HTML.
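
        A minimal usage sketch (``prof`` stands for a profiler that has already finished
        collection and was created with ``record_shapes=True``, ``profile_memory=True``
        and ``with_stack=True``; the file name and device string are only examples):

        .. code-block:: python

            # write an HTML plot of the memory timeline for the first CUDA device
            prof.export_memory_timeline("memory_timeline.html", device="cuda:0")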
        """
        # Default to device 0, if unset. Fallback on cpu.
        if device is None and self.use_device and self.use_device != "cuda":
            device = self.use_device + ":0"

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Construct the memory timeline plot data
        self.mem_tl = MemoryProfileTimeline(self._memory_profile())

        # Depending on the file suffix, save the data as json.gz or json.
        # For html, we can embed the image into an HTML file.
        if path.endswith(".html"):
            self.mem_tl.export_memory_timeline_html(path, device)
        elif path.endswith(".gz"):
            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
            fp.close()
            if path.endswith("raw.json.gz"):
                self.mem_tl.export_memory_timeline_raw(fp.name, device)
            else:
                self.mem_tl.export_memory_timeline(fp.name, device)
            with open(fp.name) as fin:
                with gzip.open(path, "wt") as fout:
                    fout.writelines(fin)
            os.remove(fp.name)
        else:
            self.mem_tl.export_memory_timeline(path, device)


class ProfilerAction(Enum):
    """
    Profiler actions that can be taken at the specified intervals
    """

    NONE = 0
    WARMUP = 1
    RECORD = 2
    RECORD_AND_SAVE = 3


def schedule(
    *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
) -> Callable:
    """
    Returns a callable that can be used as the profiler ``schedule`` argument. The profiler will skip
    the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps,
    then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
    The optional number of cycles is specified with the ``repeat`` parameter; a value of zero means that
    the cycles will continue until the profiling is finished.
    """

    def schedule_fn(step: int) -> ProfilerAction:
        assert step >= 0
        if step < skip_first:
            return ProfilerAction.NONE
        else:
            step -= skip_first
        num_steps = wait + warmup + active
        if repeat > 0 and step / num_steps >= repeat:
            return ProfilerAction.NONE
        mod_step = step % num_steps
        if mod_step < wait:
            return ProfilerAction.NONE
        elif mod_step < wait + warmup:
            return ProfilerAction.WARMUP
        else:
            return (
                ProfilerAction.RECORD
                if mod_step < num_steps - 1
                else ProfilerAction.RECORD_AND_SAVE
            )

    assert (
        wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
    ), "Invalid profiler schedule arguments"
    if warmup == 0:
        warn("Profiler won't be using warmup, this can skew profiler results")
    return schedule_fn


def _default_schedule_fn(_: int) -> ProfilerAction:
    """
    Default profiler behavior - immediately starts recording the events,
    keeps doing it on every profiler step.
    """
    return ProfilerAction.RECORD


def tensorboard_trace_handler(
    dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False
):
    """
    Outputs tracing files to the ``dir_name`` directory; that directory can then be
    passed directly to tensorboard as its logdir.
    ``worker_name`` should be unique for each worker in a distributed scenario;
    it defaults to '[hostname]_[pid]'.
    """
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        nonlocal worker_name
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except Exception as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e
        if not worker_name:
            worker_name = f"{socket.gethostname()}_{os.getpid()}"
        # Use nanosecond here to avoid naming clash when exporting the trace
        file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
        if use_gzip:
            file_name = file_name + ".gz"
        prof.export_chrome_trace(os.path.join(dir_name, file_name))

    return handler_fn


class profile(_KinetoProfile):
    """Profiler context manager.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
            ``torch.profiler.ProfilerActivity.XPU``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
            or (when available) ProfilerActivity.XPU.
        schedule (Callable): callable that takes step (int) as a single parameter and returns
            the ``ProfilerAction`` value that specifies the profiler action to perform at each step.
        on_trace_ready (Callable): callable that is called at each step when ``schedule``
            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation.
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use a formula to estimate the FLOPs (floating point operations) of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record the module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        experimental_config (_ExperimentalConfig) : A set of experimental options
            used for Kineto library features. Note, backward compatibility is not guaranteed.
        execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
            `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
            representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
            When this argument is included, the observer's start() and stop() will be called for the
            same time window as the PyTorch profiler. See the examples section below for a code sample.
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
        use_cuda (bool):
            .. deprecated:: 1.8.1
                use ``activities`` instead.

    .. note::
        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
        Non-default schedules are useful when profiling long training jobs
        and allow the user to obtain multiple traces at the different iterations
        of the training process.
        The default schedule simply records all the events continuously for the
        duration of the context manager.
    .. note::
        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:

        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``

        After profiling, result files can be found in the specified directory. Use the command:

        ``tensorboard --logdir dir_name``

        to see the results in TensorBoard.
        For more information, see
        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__

    .. note::
        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.

    Examples:

    .. code-block:: python

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            code_to_profile()
        print(p.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))

    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

    .. code-block:: python

        # Non-default profiler schedule allows user to turn profiler on and off
        # on different iterations of the training loop;
        # trace_handler is called every time a new trace becomes available
        def trace_handler(prof):
            print(prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],

            # In this example with wait=1, warmup=1, active=2, repeat=1,
            # profiler will skip the first step/iteration,
            # start warming up on the second, record
            # the third and the fourth iterations,
            # after which the trace will become available
            # and on_trace_ready (when set) is called;
            # the cycle repeats starting with the next step

            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2,
                repeat=1),
            on_trace_ready=trace_handler
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            # used when outputting for tensorboard
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                # send a signal to the profiler that the next iteration has started
                p.step()

    The following sample shows how to set up an Execution Trace Observer (``execution_trace_observer``):

    .. code-block:: python

        with torch.profiler.profile(
            ...
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback("./execution_trace.json")
            ),
        ) as p:
            for iter in range(N):
                code_iteration_to_profile(iter)
                p.step()

    You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
    Note: One can also pass any object satisfying the _ITraceObserver interface.
    """

    def __init__(
        self,
        *,
        activities: Optional[Iterable[ProfilerActivity]] = None,
        schedule: Optional[Callable[[int], ProfilerAction]] = None,
        on_trace_ready: Optional[Callable[..., Any]] = None,
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        experimental_config: Optional[_ExperimentalConfig] = None,
        execution_trace_observer: Optional[_ITraceObserver] = None,
        acc_events: bool = False,
        # deprecated:
        use_cuda: Optional[bool] = None,
    ):
        activities_set = set(activities) if activities else supported_activities()
        if use_cuda is not None:
            warn(
                "`use_cuda` is deprecated, use `activities` argument instead",
                FutureWarning,
                stacklevel=2,
            )
            if use_cuda:
                activities_set.add(ProfilerActivity.CUDA)
            elif ProfilerActivity.CUDA in activities_set:
                activities_set.remove(ProfilerActivity.CUDA)
        assert len(activities_set) > 0, "No valid profiler activities found"

        super().__init__(
            activities=activities,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
            experimental_config=experimental_config,
            execution_trace_observer=execution_trace_observer,
            acc_events=acc_events,
        )

        if schedule:
            self.schedule = schedule
            # add step markers into the trace and table view
            self.record_steps = True
        else:
            self.schedule = _default_schedule_fn
            self.record_steps = False
        self.on_trace_ready = on_trace_ready
        self.step_num = 0
        self.current_action = self.schedule(self.step_num)
        self.step_rec_fn: Optional[prof.record_function] = None

        self.action_map: Dict[
            Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]
        ] = {
            # key is (prev_action, current_action), value is action list corresponding to the state pair.
            (ProfilerAction.NONE, ProfilerAction.NONE): [],
            (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
            (ProfilerAction.NONE, ProfilerAction.RECORD): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
                self.start_trace,
                self.stop_trace,
            ],
            (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
            (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
            (ProfilerAction.RECORD, ProfilerAction.NONE): [
                partial(warn, "Incorrect schedule: RECORD followed by NONE"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
                partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
                self.stop_trace,
            ],
            (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
            (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
                self.stop_trace,
                self._trace_ready,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
                self.stop_trace,
                self._trace_ready,
                self.prepare_trace,
                self.start_trace,
            ],
            # used for exit action
            (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
            (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
            (ProfilerAction.RECORD_AND_SAVE, None): [
                self.stop_trace,
                self._trace_ready,
            ],
        }
        # Start tracking increments to profiler step, this will be used
        # by Kineto
        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
        if self.execution_trace_observer:
            self.execution_trace_observer.cleanup()

    def start(self):
        self._transit_action(ProfilerAction.NONE, self.current_action)
        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def stop(self):
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        self._transit_action(self.current_action, None)

    def step(self):
        """
        Signals the profiler that the next profiling step has started.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        prev_action = self.current_action
        self.step_num += 1
        self.current_action = self.schedule(self.step_num)

        self._transit_action(prev_action, self.current_action)
        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

        if self.record_steps:
            self.step_rec_fn = prof.record_function(
                "ProfilerStep#" + str(self.step_num)
            )
            self.step_rec_fn.__enter__()

    def _trace_ready(self):
        if self.on_trace_ready:
            self.on_trace_ready(self)

    def _transit_action(self, prev_action, current_action):
        action_list = self.action_map.get((prev_action, current_action))
        if action_list:
            for action in action_list:
                action()

    def _stats(self) -> Optional[prof._ProfilerStats]:
        if self.profiler is None:
            return None
        return self.profiler._stats


class ExecutionTraceObserver(_ITraceObserver):
    """Execution Trace Observer

    Each process can have a single ExecutionTraceObserver instance. The observer
    can be added to record function callbacks by calling register_callback()
    explicitly. Without calling unregister_callback(), repeated calls to
    register_callback() will not add additional observers to record function
    callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
    methods control when the event data is recorded.

    Deleting or calling unregister_callback() will remove the observer from the
    record function callbacks, finalize the output file, and stop
    incurring any overhead.
    """

    def __init__(self) -> None:
        """
        Initializes the default states.
        """
        self._registered = False
        self._execution_trace_running = False

    def __del__(self):
        """
        Calls unregister_callback() to make sure to finalize outputs.
        """
        self.unregister_callback()

    def register_callback(self, output_file_path: str) -> Self:
        """
        Adds ET observer to record function callbacks. The data will be
        written to output_file_path.
        """
        if not self._registered:
            self._output_file_path = output_file_path
            self._registered = _add_execution_trace_observer(output_file_path)
        return self

    def unregister_callback(self):
        """
        Removes ET observer from record function callbacks.
        """

        def _save_triton_kernels():
            # Save the kernel paths for the generated kernels
            from torch._inductor.codecache import PyCodeCache as PyCodeCache

            kernel_files = [
                v.__file__
                for v in PyCodeCache.cache.values()
                if getattr(v, "__file__", None) is not None
            ]
            work_dir, file_name = os.path.split(self._output_file_path)
            resource_dir = os.path.join(
                work_dir, os.path.splitext(file_name)[0] + "_resources"
            )
            if not os.path.exists(resource_dir):
                os.mkdir(resource_dir)

            for kernel_file in kernel_files:
                if kernel_file is None:
                    continue
                path, name = os.path.split(kernel_file)
                dst = os.path.join(resource_dir, name)
                shutil.copyfile(kernel_file, dst)

        if self._registered:
            self.stop()
            try:
                _save_triton_kernels()
            except Exception as e:
                warn(f"Execution trace failed to save kernels: {e}")
            _remove_execution_trace_observer()
            self._registered = False

    @property
    def is_registered(self):
        """
        Returns True if the execution trace observer is registered, otherwise False.
        """
        return self._registered

    def is_running(self):
        """
        Returns True if the observer is running, otherwise False.
        """
        return self._execution_trace_running

    def start(self):
        """
        Starts capturing.
        """
        if self._registered and not self._execution_trace_running:
            _enable_execution_trace_observer()
            self._execution_trace_running = True
            self._record_pg_config()

    def stop(self):
        """
        Stops capturing.
        """
        if self._execution_trace_running:
            _disable_execution_trace_observer()
            self._execution_trace_running = False

    def cleanup(self):
        """
        Calls unregister_callback() to make sure to finalize outputs.
        """
        self.unregister_callback()

    def get_output_file_path(self) -> str:
        """
        Returns the output file path.
        """
        if self.is_registered:
            return self._output_file_path
        else:
            raise RuntimeError(
                "A callback to the ET profiler needs to be registered "
                "first before getting the output file path"
            )

    def _record_pg_config(self) -> None:
        # Records the PG config info to the trace as node:
        #   ## process_group:init ##
        if (
            self.is_registered
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
            torch.autograd._record_function_with_args_enter(
                "## process_group:init ##", json.dumps(pg_config_info)
            )