# mypy: allow-untyped-defs
import collections
import dataclasses
import enum
import itertools as it
import logging
from typing import (
    Any,
    cast,
    DefaultDict,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)
from typing_extensions import Literal

import torch
from torch._C import FunctionSchema
from torch._C._autograd import _ProfilerResult
from torch._C._profiler import (
    _EventType,
    _ExtraFields_Allocation,
    _ExtraFields_TorchOp,
    _ProfilerEvent,
    _TensorMetadata,
    RecordScope,
)
from torch._utils import _element_size
from torch.profiler import _utils


KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int]

log = logging.getLogger(__name__)


class Category(enum.Enum):
    INPUT = enum.auto()
    TEMPORARY = enum.auto()
    ACTIVATION = enum.auto()
    GRADIENT = enum.auto()
    AUTOGRAD_DETAIL = enum.auto()
    PARAMETER = enum.auto()
    OPTIMIZER_STATE = enum.auto()


_CATEGORY_TO_COLORS = {
    Category.PARAMETER: "darkgreen",
    Category.OPTIMIZER_STATE: "goldenrod",
    Category.INPUT: "black",
    Category.TEMPORARY: "mediumpurple",
    Category.ACTIVATION: "red",
    Category.GRADIENT: "mediumblue",
    Category.AUTOGRAD_DETAIL: "royalblue",
    None: "grey",
}

_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)}


class Action(enum.Enum):
    PREEXISTING = enum.auto()
    CREATE = enum.auto()
    INCREMENT_VERSION = enum.auto()
    DESTROY = enum.auto()


_ACTION_TO_INDEX = {i: i.value for i in Action}


@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
class Key:
    device: torch.device


@dataclasses.dataclass
class _Storage:
    """Bundle storage pointer and id.

    All profiling logic should use `allocation_id`; however, it is useful to
    print storage pointers for debugging, and unit tests sometimes look up
    values using the storage data pointer of a live Tensor."""

    ptr: int
    allocation_id: int

    def __repr__(self) -> str:
        return f"{hex(self.ptr):>18} ({self.allocation_id})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _Storage) and self.allocation_id == other.allocation_id

    def __hash__(self) -> int:
        return hash(self.allocation_id)

@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey(Key):
    """Hashable identifier for a storage which has been assigned an ID.

    A detailed description of Tensor IDs and why they are needed is given in
    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
    summarize, multiple Storage buffers can map to the same logical Tensor.
    This dataclass is used to refer to a concrete in-memory StorageImpl of
    a Tensor.
    """

    id: int
    storage: _Storage

    def __repr__(self) -> str:
        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"

    def __lt__(self, other: "TensorKey") -> bool:
        return self._as_sortable < other._as_sortable

    @staticmethod
    def _make(
        tensor_id: Optional[int],
        storage_ptr: Optional[int],
        allocation_id: Optional[int],
        device: torch.device,
    ) -> Optional["TensorKey"]:
        if (
            tensor_id is not None
            and storage_ptr is not None
            and allocation_id is not None
        ):
            return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id))
        return None

    @classmethod
    def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
        return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)

    @classmethod
    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
        if t is not None:
            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
        return None

    @property
    def _as_sortable(self) -> Tuple[int, int, str, int]:
        return self.id, self.storage.allocation_id, self.device.type, self.device.index


def _extract_parameters_and_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
    children = node.children

    # AccumulateGrad is used in the Autograd engine to handle gradient updates.
    # There are two possible cases:
    # 1) This is a newly created gradient Tensor. In that case there is nothing
    #    to accumulate, so autograd simply detaches the Tensor.
    #
    # 2) There is a preexisting gradient Tensor and we need to add the newly
    #    computed update. This is done with an in-place add (aten::add_) op.
    #    (The underscore suffix denotes "in-place".)
    if (
        node.typed[0] == _EventType.TorchOp
        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
        # TODO(robieta): Move away from load bearing names
        and node.name == "torch::autograd::AccumulateGrad"
        and children
        and children[0].typed[0] == _EventType.TorchOp
        and children[0].name in ("aten::detach", "aten::add_")
        and children[0].typed[1].inputs
        and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
    ):
        yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0])

    # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
    # NOTE: The values captured by the python tracer are cached; they can be
    #       used to build up labels but do not imply that a Tensor was live at
    #       a particular time.
    elif node.typed[0] == _EventType.PyCall:
        typed_fields = node.typed[1]
        assert typed_fields.module is None or typed_fields.optimizer is None
        if typed_fields.module is not None:
            for _, p, p_grad in typed_fields.module.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)

        if typed_fields.optimizer is not None:
            for p, p_grad, _ in typed_fields.optimizer.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)

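# For illustration (not exhaustive): a backward `torch::autograd::AccumulateGrad`
# node whose first child is `aten::detach` or `aten::add_` yields
# `(None, grad_key)`, since only the gradient Tensor is known at that point,
# while a PyCall that captured an `nn.Module` or `optim.Optimizer` yields a
# `(param_key, grad_key)` pair for each parameter it recorded.
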
def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
    for p, p_grad in _extract_parameters_and_gradients(node):
        if p is not None:
            yield p


def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    for p, p_grad in _extract_parameters_and_gradients(node):
        if p_grad is not None:
            yield p, p_grad


def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
    scopes = []
    while event:
        if event.typed[0] == _EventType.TorchOp:
            scopes.append(event.typed[1].scope)
        event = event.parent
    return tuple(scopes)


class SchemaMatcher:
    """Lookup operator schema based on profiled name.

    When profiling we record the operator's name but not the schema. However
    some analysis requires that information. Fortunately we can look up
    registered schema from the recorded name. We do not, however, record the
    overload and so we must compare the profiled arguments with all overloads
    to determine viable matches.

    Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
    this code will be obsolete.
    """

    @classmethod
    def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
        """Determine which inputs may have been mutated based on function schema.

        Note that we don't need to resolve down to a single schema to perform
        this analysis. An input is mutable if it is mutable in any overload. In
        practice, however, it is overwhelmingly common to match a single
        overload. If we cannot find any valid schema then we must be
        conservative and assume all inputs are mutable.
        """
        mutable: Optional[List[bool]] = None
        for schema in cls.match_schemas(t):
            mutable = mutable or [False for _ in schema.arguments]
            for i, arg in enumerate(schema.arguments):
                mutable[i] |= getattr(arg.alias_info, "is_write", False)

        return tuple(mutable or (None for _ in t.inputs))

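    # Worked example (illustrative; schema text paraphrased from the ATen
    # operator registry):
    #
    #   aten::add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
    #
    # The `(a!)` annotation marks `self` as written-to, so for a profiled
    # `aten::add_` call matching this overload `inputs_are_mutable` would
    # return roughly `(True, False, False)`. If no schema matches at all,
    # every entry is `None` and callers must conservatively assume mutation.
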
    @classmethod
    def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
        signature = tuple(
            # Tensor
            TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
            #
            # TensorList
            else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
            #
            # Scalar and uncaptured inputs.
            else i
            for i in t.inputs
        )

        def matches(schema) -> bool:
            return len(schema.arguments) == len(signature) and all(
                cls._types_match(observed, schema_arg.type)
                for observed, schema_arg in zip(signature, schema.arguments)
            )

        return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))

    @classmethod
    def _types_match(cls, observed, schema_type) -> bool:
        if isinstance(schema_type, torch._C.OptionalType):
            schema_type = schema_type.getElementType()
            return observed is None or cls._types_match(observed, schema_type)

        if isinstance(schema_type, torch._C.AnyType):
            return True

        if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
            return isinstance(observed, list) and all(
                isinstance(i, TensorKey) for i in observed
            )

        type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
            (torch._C.TensorType, TensorKey),
            (torch._C.NoneType, type(None)),
            (torch._C.BoolType, bool),
            (torch._C.IntType, int),
            (torch._C.FloatType, float),
            (torch._C.ComplexType, complex),
            (torch._C.NumberType, (bool, int, float, complex)),
        )

        for jit_type, py_types in type_map:
            if isinstance(schema_type, jit_type):
                return isinstance(observed, py_types)

        # Profiler only records a subset of possible argument types. If we
        # reach this point then the schema must call for a type that profiler
        # does not record. Thus, the schema can only be a match if `observed`
        # is also None.
        return observed is None

    @staticmethod
    def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
        # TODO(robieta):
        #   _jit_get_schemas_for_operator is quite expensive. (~100us / call)
        #   Consider adding `functools.lru_cache` if that becomes an issue.

        try:
            # Schema lookup will throw if `name` is malformed. (For example,
            # schemas must be namespaced and schema lookup will fail if name
            # does not include "::".) We simply catch the exception and return
            # `None` to denote that `name` cannot be an operator name.
            #
            # Note that record_function annotations also go through this path,
            # so it is expected that some names will not correspond to PyTorch
            # operators.
            if "::" not in name:
                return None
            return tuple(torch._C._jit_get_schemas_for_operator(name))
        except RuntimeError:
            return None

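# NOTE: The TODO above suggests memoization if schema lookup ever becomes a
# bottleneck. A minimal sketch (hypothetical; `_cached_lookup_schemas` is not
# part of this module) could look like:
#
#   import functools
#
#   @functools.lru_cache(maxsize=None)
#   def _cached_lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
#       return SchemaMatcher.lookup_schemas(name)
#
# Callers would then use `_cached_lookup_schemas(name)` in place of repeated
# `SchemaMatcher.lookup_schemas(name)` calls.
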
class OpTree:
    def __init__(self, result: _ProfilerResult) -> None:
        self._root_nodes = result.experimental_event_tree()
        self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns))

    def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
        yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs)

    @property
    def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
        return self._sorted_nodes


class SizeMap:
    def __init__(self, op_tree: OpTree) -> None:
        self._values: Dict[TensorKey, int] = {}

        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.TorchOp:
                for t in self._flat_tensor_inputs(node.typed[1]):
                    self._update_values(t)

            elif node.typed[0] == _EventType.PyCall:
                typed_fields = node.typed[1]
                assert typed_fields.module is None or typed_fields.optimizer is None
                if typed_fields.module is not None:
                    for _, p, p_grad in typed_fields.module.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)

                if typed_fields.optimizer is not None:
                    for p, p_grad, state in typed_fields.optimizer.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)
                        for _, t in state:
                            self._update_values(t)

        allocations: Dict[TensorKey, int] = {}
        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.Allocation:
                alloc_fields = node.typed[1]
                key = TensorKey.from_allocation(alloc_fields)
                if key:
                    new_size = abs(alloc_fields.alloc_size)
                    prior_size = allocations.setdefault(key, new_size)

                    # It is possible to resize Storage in PyTorch, however we
                    # key on data pointer so most resizes will be treated as a
                    # change in storage. The one corner case that cannot be
                    # handled is `realloc` which successfully resizes the
                    # storage. At time of writing this is not done anywhere in
                    # the core PyTorch codebase.
                    if prior_size != new_size:
                        delta = f"{prior_size} vs. {new_size}"
                        log.warning("Mismatch between allocation and free: %s", delta)

        self._values.update(allocations)

    def _update_values(self, t: Optional[_TensorMetadata]) -> None:
        key = TensorKey.from_tensor(t)
        if key is not None and t is not None and t.layout == torch.strided:
            # Scalars are represented as zero dim Tensors
            n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))

            num_bytes = n * _element_size(t.dtype)
            assert num_bytes >= 0, f"{num_bytes}"
            self._values[key] = max(self._values.get(key, 0), num_bytes)

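    # Worked example (for illustration only): a contiguous float32 Tensor with
    # sizes [2, 3] and strides [3, 1] gives
    #   n = max(2 * 3, 3 * 1) = 6 elements -> 6 * 4 = 24 bytes.
    # Using max(size * stride) rather than numel() approximates the extent of
    # storage the view touches, which can exceed numel() for non-contiguous
    # views.
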
    @staticmethod
    def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
        for i in op.inputs:
            if isinstance(i, _TensorMetadata):
                yield i
            elif isinstance(i, list):
                yield from i

    def __getitem__(self, key: TensorKey):
        return self._values[key]


@dataclasses.dataclass()
class DataFlowEdge:
    input_version: Optional[int] = None
    mutated: Optional[bool] = False

    @property
    def is_allocation(self) -> bool:
        return self.input_version is None

    @property
    def is_deletion(self) -> bool:
        return self.mutated is None


class DataFlowNode:
    def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
        self._event = event
        self._graph = graph
        self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()

        for key, edge in self._edges.items():
            if edge.mutated and not edge.is_allocation:
                self._graph.bump(key)

        # Make sure the version bumping behavior matches what we expect.
        versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
        assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"

    def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
        subtree = tuple(_utils.traverse_dfs([self._event]))

        # Start by populating edges from op inputs and outputs.
        mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
        for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
            for op_input, mutable in zip(
                op.inputs, SchemaMatcher.inputs_are_mutable(op)
            ):
                # Tensor
                if isinstance(op_input, _TensorMetadata):
                    key = TensorKey.from_tensor(op_input)
                    mutable_by_key.setdefault(key, set()).add(mutable)

                # TensorList
                elif isinstance(op_input, list):
                    for op_input_i in op_input:
                        key = TensorKey.from_tensor(op_input_i)
                        mutable_by_key.setdefault(key, set()).add(mutable)

        edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
        edges = collections.defaultdict(DataFlowEdge)
        for key, mutable_set in mutable_by_key.items():
            if key is not None:
                edges[key].input_version = self._graph.lookup(key) if key else -1

                # We consider an input to be mutated if we encounter a schema
                # where it is a mutable argument OR if it is ambiguous. (We
                # never explicitly see it in any schema.)
                mutated = (True in mutable_set) or (tuple(mutable_set) == (None,))
                edges[key].mutated = mutated

        # Then handle deletions. Note that deleting a Tensor implicitly adds
        # it as an input edge.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0:
                key = TensorKey.from_allocation(i.typed[1])
                edge = edges[key]
                assert key is None or edge.mutated is not None, f"Double delete: {key}"
                edge.mutated = None
                edge.input_version = self._graph.lookup(key) if key else -1

        # And finally handle allocations. This step must be last, because the
        # previous two steps optimistically add input edges.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0:
                edges[TensorKey.from_allocation(i.typed[1])].input_version = None

        # We don't need to sort the inputs, but it makes debugging and unit tests nicer.
        return dict(sorted((k, v) for k, v in edges.items() if k is not None))

    @property
    def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]:
        return {
            # MyPy can't see through `is_allocation` to know that
            # `v.input_version` is not None.
            k: (bool(v.mutated), cast(int, v.input_version))
            for k, v in self._edges.items()
            if not v.is_allocation
        }

    @property
    def outputs(self) -> Dict[TensorKey, int]:
        return {
            k: 0 if v.input_version is None else v.input_version + 1
            for k, v in self._edges.items()
            if (v.is_allocation and not v.is_deletion) or v.mutated
        }

    @property
    def intermediates(self) -> Tuple[TensorKey, ...]:
        return tuple(
            k for k, v in self._edges.items() if v.is_allocation and v.is_deletion
        )

    @property
    def start_time(self) -> int:
        return self._event.start_time_ns

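# For intuition (illustrative values only): an in-place op such as `x.add_(y)`
# executed while `x` has active version 2 would produce a DataFlowNode with
#   inputs  = {x: (True, 2), y: (False, 0)}
#   outputs = {x: 3}
# whereas an op that allocates a fresh output `z` records `z` with
# `input_version=None` (an allocation edge) and `outputs = {z: 0}`.
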
class DataFlowGraph:
    def __init__(self, op_tree: OpTree) -> None:
        self._op_tree = op_tree
        self._leaf_events = self._extract_leaf_events(op_tree)
        self._active_version: Dict[TensorKey, Optional[int]] = {}
        self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
        self._flow_nodes.sort(key=lambda x: x.start_time)
        self.validate()

    @property
    def flow_nodes(self) -> Tuple[DataFlowNode, ...]:
        return tuple(self._flow_nodes)

    def validate(self):
        # Check that each (Tensor, version) pair has a unique creation node
        outputs: Set[Tuple[TensorKey, int]] = set()
        for node in self.flow_nodes:
            node_outputs = set(node.outputs.items())
            duplicates = outputs & node_outputs
            assert not duplicates, f"{node._event.name} {node._edges} {duplicates}"
            outputs |= node_outputs

        # And check that `self._flow_nodes` forms a valid topologically sorted DAG.
        tensor_versions: Dict[TensorKey, int] = {}
        for node in self.flow_nodes:
            for key, (_, version) in node.inputs.items():
                expected = tensor_versions.get(key, 0)
                assert expected == version, (expected, version)

            for key, version in node.outputs.items():
                prior_version = tensor_versions.get(key, version)
                assert version >= prior_version, (version, prior_version)
                tensor_versions[key] = version

    @property
    def leaf_events(self) -> Tuple[_ProfilerEvent, ...]:
        return self._leaf_events

    @staticmethod
    def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]:
        """Partially traverse the op tree and extract top level ops.

        Consider the following code:
        ```
        with record_function("My annotation"):
            x.zero_()
            y.zero_()
        ```

        The op tree (assuming no Autograd) will look like:
          <Python context>
            TorchOp: "My annotation"
              TorchOp: zero_
                TorchOp: fill_
              TorchOp: zero_
                TorchOp: fill_

        The recursive structure of operator calls makes data flow unwieldy.
        In order to simplify analysis we would like to select the highest level
        ops to represent in the graph. In this case those are the `zero_` ops;
        the fact that `fill_` is called is an implementation detail. We also
        do not want to group everything under "My annotation" as this could
        create overly coarse bundles and lose critical semantics.

        To address this issue we walk over the graph and select the topmost
        torch ops **which match at least one operator schema**. These form
        the leaves of the first pass through the op tree. (As well as any
        allocations or frees which are not part of a kernel.) These events
        form the logical nodes in our data flow graph.
        """

        leaf_events: List[_ProfilerEvent] = []

        def leaf_op(e: _ProfilerEvent) -> bool:
            return e.typed[0] == _EventType.TorchOp and (
                e.typed[1].scope == RecordScope.BACKWARD_FUNCTION
                or bool(SchemaMatcher.match_schemas(e.typed[1]))
            )

        def children_fn(e: _ProfilerEvent):
            if leaf_op(e) or e.tag == _EventType.Allocation:
                leaf_events.append(e)
                return []

            return e.children

        for _ in op_tree.dfs(children_fn=children_fn):
            pass

        return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns))

    def lookup(self, key: TensorKey) -> int:
        version = self._active_version.setdefault(key, 0)
        assert version is not None
        return version

    def bump(self, key: TensorKey) -> None:
        prior_version = self._active_version.get(key, None)
        assert prior_version is not None
        self._active_version[key] = prior_version + 1

    def delete(self, key: TensorKey) -> None:
        assert self._active_version.setdefault(key, 0) is not None
        self._active_version[key] = None

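# Version bookkeeping in a nutshell (illustrative): `lookup` lazily assigns
# version 0 the first time a Tensor is seen, `bump` advances the active
# version after an in-place mutation, and `delete` retires the Tensor by
# setting its active version to None. A Tensor that is created, mutated once,
# and then freed therefore passes through versions 0 -> 1 -> None.
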
@dataclasses.dataclass
class CategoryElement:
    by_id: Optional[Category] = None
    by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
    by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)

    # Used by unit tests to check internals. (And consequently by
    # MemoryProfile.lookup) This should not be used in any other capacity.
    _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)


@dataclasses.dataclass
class CategoryDict:
    _values: DefaultDict[int, CategoryElement] = dataclasses.field(
        default_factory=lambda: collections.defaultdict(CategoryElement)
    )

    def set_by_id(self, key: TensorKey, category: Category) -> None:
        self._values[key.id].by_id = category
        self._values[key.id]._by_id_keyset.add(key)

    def set_by_key(self, key: TensorKey, category: Category) -> None:
        self._values[key.id].by_key[key] = category

    def set_by_version(self, key: TensorKey, version: int, category: Category) -> None:
        self._values[key.id].by_version[(key, version)] = category

    def setdefault_by_version(
        self, key: TensorKey, version: int, category: Category
    ) -> None:
        self._values[key.id].by_version.setdefault((key, version), category)

    def get(self, key: Key, version: int) -> Optional[Category]:
        if isinstance(key, Key) and not isinstance(key, TensorKey):
            return None
        element = self._values[key.id]
        return (
            element.by_id
            or element.by_key.get(key, None)
            or element.by_version.get((key, version), None)
        )

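# Category lookups resolve from coarsest to finest granularity: a category set
# for an entire Tensor ID wins over one set for a particular storage, which in
# turn wins over one set for a specific (storage, version) pair. For example,
# once a Tensor ID has been marked PARAMETER via `set_by_id`, later
# `setdefault_by_version(..., Category.ACTIVATION)` calls have no effect on
# what `get` returns for that Tensor.
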
class MemoryProfile:
    def __init__(self, result: _ProfilerResult) -> None:
        self._op_tree = OpTree(result)
        self._data_flow_graph = DataFlowGraph(self._op_tree)
        self._size_map = SizeMap(self._op_tree)
        self._categories = CategoryDict()

        self._set_gradients_and_temporaries()
        self._set_parameters_using_python_tracer()
        self._set_inputs()
        self._set_parameters_using_data_flow()
        self._set_activations()
        self._set_optimizer_state()
        self._set_autograd_detail()

    @property
    def timeline(self) -> Tuple[Tuple[int, Action, KeyAndID, int], ...]:
        output: List[Tuple[int, Action, KeyAndID, int]] = []
        allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
        live_unknown: Dict[Tuple[int, torch.device], Literal[True]] = {}
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.Allocation:
                alloc_fields = event.typed[1]
                alloc_size = alloc_fields.alloc_size
                is_allocation = alloc_size > 0
                t = event.start_time_ns

                tkey = TensorKey.from_allocation(alloc_fields)
                if tkey is not None:
                    allocation_times[(tkey, is_allocation)] = t

                else:
                    key = Key(alloc_fields.device)
                    ptr_and_device = (alloc_fields.ptr, key.device)
                    if is_allocation:
                        if ptr_and_device in live_unknown:
                            output.append(
                                (t, Action.INCREMENT_VERSION, (key, 0), alloc_size)
                            )
                        else:
                            live_unknown[ptr_and_device] = True
                            output.append((t, Action.CREATE, (key, 0), alloc_size))
                    else:
                        output.append((t, Action.DESTROY, (key, 0), -alloc_size))
                        if not live_unknown.pop(ptr_and_device, False):
                            output.append(
                                (-1, Action.PREEXISTING, (key, 0), -alloc_size)
                            )

        snapshot = self._category_snapshot()
        last_version = dict(sorted(snapshot.keys()))

        events: List[Tuple[int, Action, TensorAndID]] = [
            (-1, Action.PREEXISTING, (key, version))
            for key, version in snapshot.keys()
            if (key, True) not in allocation_times and version == 0
        ]

        for node in self._data_flow_graph.flow_nodes:
            for key, edge in node._edges.items():
                if edge.is_allocation:
                    t = allocation_times[(key, True)]
                    events.append((t, Action.CREATE, (key, 0)))

                elif edge.mutated:
                    t = node._event.start_time_ns
                    version = edge.input_version
                    assert version is not None
                    events.append((t, Action.INCREMENT_VERSION, (key, version)))

                if edge.is_deletion:
                    t = allocation_times[(key, False)]
                    events.append((t, Action.DESTROY, (key, last_version[key])))

        output.extend(
            (time, action, (key, version), self._size_map[key])
            for time, action, (key, version) in events
        )

        output.sort(key=lambda x: (x[0], x[1].value))
        return tuple(output)

    def _is_gradient(self, *args, **kwargs) -> bool:
        return self._categories.get(*args, **kwargs) == Category.GRADIENT

    def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
        all_tensor_versions: Set[TensorAndID] = set()

        for node in self._data_flow_graph.flow_nodes:
            all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
            all_tensor_versions.update((key, 0) for key in node.intermediates)
            all_tensor_versions.update(node.outputs.items())

        for i in self._categories._values.values():
            all_tensor_versions.update((key, 0) for key in i._by_id_keyset)

        return {
            (key, version): self._categories.get(key, version)
            for key, version in sorted(all_tensor_versions)
        }

    def _any_version_depends_on_gradient(self) -> Set[int]:
        """Extract IDs of Tensors which depend or will depend on a gradient.

        Note that this weakened definition of "depends" requires us to loop
        over the data flow graph multiple times because it allows dependency
        information to flow backward through edges and removes the guarantee
        that nodes are topologically sorted. (Or indeed, even that a valid
        topological order exists.) Put another way, we have converted an
        acyclic data flow graph into a cyclic graph and we are attempting to
        partition cycles involving a gradient from the rest of the graph.
        """
        depends_on_gradient: Set[int] = set()
        while True:
            start_size = len(depends_on_gradient)
            for node in self._data_flow_graph.flow_nodes:
                ids = tuple(
                    key.id
                    for key, (_, version) in node.inputs.items()
                    if self._categories.get(key, version)
                    in (Category.GRADIENT, Category.PARAMETER)
                    or key.id in depends_on_gradient
                )

                if ids:
                    depends_on_gradient.update(ids)
                    depends_on_gradient.update(key.id for key in node.outputs)

            # We are guaranteed to exit because there is a finite set of
            # TensorAndID pairs. In practice we do not expect to loop more than
            # three times: once to identify the core parameter update loop,
            # once to fold the first step into that loop, and a third time
            # where no new elements are added.
            if len(depends_on_gradient) == start_size:
                return depends_on_gradient

    def _set_gradients_and_temporaries(self) -> None:
        """Mark Tensors which are unambiguous and simple to reason about."""

        # Gradients are straightforward to detect. We directly check the
        # `.grad` property in the Python tracer, and we can detect any new
        # gradient Tensors from `AccumulateGrad` ops.
        for event in self._op_tree.dfs():
            for _, p_grad in extract_gradients(event):
                self._categories.set_by_id(p_grad, Category.GRADIENT)

        # Similarly, temporary Tensors are easy to identify and are useful to
        # flag since they can make memory use "spikier" than one would
        # otherwise expect.
        for node in self._data_flow_graph.flow_nodes:
            for i in node.intermediates:
                self._categories.set_by_key(i, Category.TEMPORARY)

    def _set_parameters_using_python_tracer(self) -> None:
        for event in self._op_tree.dfs():
            for p in extract_parameters(event):
                if p is not None:
                    self._categories.set_by_id(p, Category.PARAMETER)

    def _set_inputs(self) -> None:
        """Mark inputs based on which Tensors are updated using gradients.

        The process for differentiating between inputs and activations is more
        involved. Most Tensors in a training loop depend on at least one
        gradient: parameters depend on them through updates, and activations
        and optimizer state depend on them transitively through parameters.
        Critically, we do not need to know which Tensors are parameters to
        apply this method; we can simply walk the data flow graph to build the
        set of all values which depend on a gradient and then obtain the set
        of inputs from the complement of that set.

        There is, however, one hiccup. The first time we see a parameter is
        generally on the forward pass of the first step. We know from
        inspection of the data flow graph that v1 of that Tensor depends on
        a gradient (provided we profile an optimizer step), but not v0. To
        address this problem we weaken the definition of "depends on a
        gradient" to "any version of this Tensor depends on a gradient",
        which in turn strengthens the criteria for the input set enough to
        filter the activations in the forward pass of the first step."""

        # All of this analysis is predicated on using at least one training
        # step (or parameters from the python tracer) to partition the graph.
        # Absent that we cannot determine which Tensors are inputs and which
        # ones are part of the model.
        depends_on_gradient = self._any_version_depends_on_gradient()

        # We only want to annotate Tensors which actually contribute to the
        # model calculation.
        produces_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            tensors = {(key, version) for key, (_, version) in node.inputs.items()}
            tensors |= node.outputs.items()
            if any(
                self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER)
                or i in produces_gradient
                for i in tensors
            ):
                produces_gradient |= tensors

        # Don't include Tensors created in the backward pass, as these are
        # generally Autograd implementation details rather than proper inputs.
        input_candidates = produces_gradient.copy()
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                input_candidates -= set(node.outputs.items())

        for key, version in input_candidates:
            if key.id not in depends_on_gradient:
                self._categories.setdefault_by_version(key, version, Category.INPUT)

    def _set_parameters_using_data_flow(self) -> None:
        """Deduce which Tensors are parameters.

        Consider the following code for the step of SGD with momentum
        (nesterov=False), where `d_p` is the gradient of `param` and `buf` is
        the momentum buffer.
        ```
        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
        d_p = buf
        param.add_(d_p, alpha=-lr)
        ```
        Both `param` and `buf` take a gradient and perform an in-place update.

        The python tracer will inspect calls to `nn.Module.forward` and
        `optim.Optimizer.step` to extract parameter and optimizer state
        respectively (including parameters), so this is generally a non-issue.

        However as a fallback we can also exploit several properties of
        parameters to distinguish them from other model state.

        First, they are directly used in the forward pass. (At this point we
        haven't established which parts of the graph correspond to the forward
        pass but we can deduce enough to suffice.) Some mutable state such as
        batch norm moving averages also contribute to the forward pass, but
        optimizer state does not.

        Second, a parameter is by definition used to compute at least one
        gradient and depends on at least one gradient.
        """
        snapshot = self._category_snapshot()

        # Determine which Tensors might be parameters based on forward pass
        # data flow. Note that these are only candidates; we filter nodes that
        # we know are part of the backward pass but that doesn't guarantee that
        # they are part of the forward pass.
        candidate_parameters: Set[TensorAndID] = set()
        candidate_fwd_tensors: Set[TensorAndID] = {
            i for i, category in snapshot.items() if category == Category.INPUT
        }

        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
            if (
                # Don't check nodes in the backward pass.
                RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
                and not any(self._is_gradient(*i) for i in inputs)
                and not any(self._is_gradient(*i) for i in node.outputs.items())
                #
                # and only check nodes which depend on an input.
                and candidate_fwd_tensors.intersection(inputs)
            ):
                candidate_fwd_tensors |= node.outputs.items()
                candidate_parameters |= inputs.difference(candidate_fwd_tensors)

        # Require that each parameter eventually contributes to the value of a gradient
        used_for_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            if any(
                self._is_gradient(*i) or i in used_for_gradient
                for i in node.outputs.items()
            ):
                used_for_gradient.update(
                    (key, version) for key, (_, version) in node.inputs.items()
                )
        candidate_parameters.intersection_update(used_for_gradient)

        # and depends on a gradient.
        parameter_keys = {key.id for key, _ in candidate_parameters}
        parameter_keys &= self._any_version_depends_on_gradient()

        for key, _ in snapshot.keys():
            if key.id in parameter_keys:
                self._categories.set_by_id(key, Category.PARAMETER)

    def _set_activations(self) -> None:
        """Flood the graph to identify activations."""

        required = {Category.INPUT, Category.ACTIVATION}
        also_allowed = {Category.PARAMETER, Category.TEMPORARY}
        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
            input_categories = {self._categories.get(*i) for i in inputs}

            if (
                (input_categories & required)
                and not (input_categories - (required | also_allowed))
                #
                # Stop filling when we reach the backward pass.
                and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
            ):
                for i in node.outputs.items():
                    self._categories.setdefault_by_version(*i, Category.ACTIVATION)

    def _set_optimizer_state(self) -> None:
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer:
                parameters = event.typed[1].optimizer.parameters
                for _, t in it.chain(*[state for _, _, state in parameters]):
                    key = TensorKey.from_tensor(t)
                    if key is not None:
                        self._categories.set_by_id(key, Category.OPTIMIZER_STATE)

    def _set_autograd_detail(self):
        prior = {None, Category.AUTOGRAD_DETAIL}
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                for key, version in node.outputs.items():
                    if version == 0 or self._categories.get(key, version - 1) in prior:
                        self._categories.setdefault_by_version(
                            key, version, Category.AUTOGRAD_DETAIL
                        )

class MemoryProfileTimeline:
    def __init__(self, memory_profile):
        """The minimum representation of the memory profile timeline
        includes the memory timeline and categories. The timeline
        consists of [timestamp, action, (TensorKey, version), numbytes]
        elements, to denote any actions (pre-existing, create, destroy,
        or increment_version) that occurred to a specific Tensor for a
        chunk of memory. The categories help map each (TensorKey,
        version) pair into a category."""
        self.timeline = memory_profile.timeline
        self.categories = memory_profile._categories

    def _coalesce_timeline(self, device_str):
        """Convert the memory timeline and categories into a memory plot
        consisting of timestamps and their respective sizes by category
        for a given device.

        Input: device
        Output: [timestamps, sizes by category]
        """
        device = torch.device(device_str)
        times: List[int] = []
        sizes: List[List[int]] = []

        def update(key, version, delta):
            category = (
                self.categories.get(key, version)
                if isinstance(key, TensorKey)
                else None
            )
            index = _CATEGORY_TO_INDEX[category] + 1
            sizes[-1][index] += int(delta)

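        # Layout note (illustrative): each row of `sizes` has one leading zero
        # entry (used as the base of the stacked plot) followed by one slot
        # per category in `_CATEGORY_TO_INDEX` order, i.e. roughly
        #   [0, PARAMETER, OPTIMIZER_STATE, INPUT, TEMPORARY, ACTIVATION,
        #    GRADIENT, AUTOGRAD_DETAIL, unknown]
        # which is why `update` writes deltas at `_CATEGORY_TO_INDEX[category] + 1`.
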
        t_min = -1
        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            # Convert timestamps from ns to us, to match trace events.
            if t != -1:
                t = int(t / 1000)

            # Save the smallest timestamp to populate pre-existing allocs.
            if t_min == -1 or (t < t_min and t > 0):
                t_min = t

            # Handle timestep
            if len(times) == 0:
                times.append(t)
                sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX])

            elif t != times[-1]:
                times.append(t)
                sizes.append(sizes[-1].copy())

            # Handle memory and categories
            if action in (Action.PREEXISTING, Action.CREATE):
                update(key, version, numbytes)

            elif action == Action.INCREMENT_VERSION:
                update(key, version, -numbytes)
                update(key, version + 1, numbytes)

            elif action == Action.DESTROY:
                update(key, version, -numbytes)

            else:
                raise ValueError(f"Unknown action: {action}")

        times = [t_min if t < 0 else t for t in times]
        return times, sizes

    def export_memory_timeline(self, path, device_str) -> None:
        """Saves the memory timeline as [times, sizes by category]
        as a JSON formatted file to the given path for the given
        device."""
        times, sizes = self._coalesce_timeline(device_str)
        # TODO: Write a faster serialize (orjson not available in CI)
        import json

        with open(path, "w") as f:
            json.dump([times, sizes], f)

    def export_memory_timeline_raw(self, path, device_str) -> None:
        """Saves the memory timeline as raw memory event tuples in the
        form of (timestamp, action, numbytes, category)
        as a JSON formatted file to the given path for the given
        device."""
        device = torch.device(device_str)
        raw_events: List[Tuple[int, int, int, int]] = []

        def get_category_index(key, version):
            category = (
                self.categories.get(key, version)
                if isinstance(key, TensorKey)
                else None
            )
            return _CATEGORY_TO_INDEX[category]

        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            if action in (Action.PREEXISTING, Action.CREATE):
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        numbytes,
                        get_category_index(key, version),
                    )
                )

            elif action == Action.INCREMENT_VERSION:
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        -numbytes,
                        get_category_index(key, version),
                    )
                )
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        numbytes,
                        get_category_index(key, version + 1),
                    )
                )

            elif action == Action.DESTROY:
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        -numbytes,
                        get_category_index(key, version),
                    )
                )

            else:
                raise ValueError(f"Unknown action: {action}")

        import json

        with open(path, "w") as f:
            json.dump(raw_events, f)

    def export_memory_timeline_html(
        self, path, device_str, figsize=(20, 12), title=None
    ) -> None:
        """Exports the memory timeline as an HTML file which contains
        the memory timeline plot embedded as a PNG file."""
        # Check if user has matplotlib installed, return gracefully if not.
        import importlib.util

        matplotlib_spec = importlib.util.find_spec("matplotlib")
        if matplotlib_spec is None:
            print(
                "export_memory_timeline_html failed because matplotlib was not found."
            )
            return

        from base64 import b64encode
        from os import remove
        from tempfile import NamedTemporaryFile

        import matplotlib.pyplot as plt
        import numpy as np

        mt = self._coalesce_timeline(device_str)
        times, sizes = np.array(mt[0]), np.array(mt[1])
        # For this timeline, start at 0 to match Chrome traces.
        t_min = min(times)
        times -= t_min
        stacked = np.cumsum(sizes, axis=1) / 1024**3
        device = torch.device(device_str)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)

        # Plot memory timeline as stacked data
        fig = plt.figure(figsize=figsize, dpi=80)
        axes = fig.gca()
        for category, color in _CATEGORY_TO_COLORS.items():
            i = _CATEGORY_TO_INDEX[category]
            axes.fill_between(
                times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7
            )
        fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS])
        # Usually training steps are in magnitude of ms.
        axes.set_xlabel("Time (ms)")
        axes.set_ylabel("Memory (GB)")
        title = "\n\n".join(
            ([title] if title else [])
            + [
                f"Max memory allocated: {max_memory_allocated/(1024**3):.2f} GiB \n"
                f"Max memory reserved: {max_memory_reserved/(1024**3):.2f} GiB"
            ]
        )
        axes.set_title(title)

        # Embed the memory timeline image into the HTML file
        tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False)
        tmpfile.close()
        fig.savefig(tmpfile.name, format="png")

        with open(tmpfile.name, "rb") as tmp:
            encoded = b64encode(tmp.read()).decode("utf-8")
            html = f"""<html>
<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head>
<body>
  <img src='data:image/png;base64,{encoded}'>
</body>
</html>"""

            with open(path, "w") as f:
                f.write(html)
            remove(tmpfile.name)

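# Typical entry point (hypothetical sketch, not part of this module): these
# classes are normally reached through `torch.profiler.profile` rather than
# instantiated directly. Assuming a profiler session `prof` collected with
# `profile_memory=True`, `record_shapes=True`, and `with_stack=True`, usage
# might look roughly like:
#
#   with torch.profiler.profile(profile_memory=True, record_shapes=True,
#                               with_stack=True) as prof:
#       train_one_step()  # user code
#
#   mem_profile = prof._memory_profile()          # assumed private accessor
#   timeline = MemoryProfileTimeline(mem_profile)
#   timeline.export_memory_timeline_html("memory.html", "cuda:0")
#
# The exact profiler configuration and the `_memory_profile()` accessor are
# assumptions here; consult the torch.profiler documentation for the supported
# public API (e.g. `prof.export_memory_timeline(...)`).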