1# mypy: allow-untyped-defs
2import collections
3import dataclasses
4import enum
5import itertools as it
6import logging
7from typing import (
8    Any,
9    cast,
10    DefaultDict,
11    Dict,
12    Iterator,
13    List,
14    Optional,
15    Set,
16    Tuple,
17    Union,
18)
19from typing_extensions import Literal
20
21import torch
22from torch._C import FunctionSchema
23from torch._C._autograd import _ProfilerResult
24from torch._C._profiler import (
25    _EventType,
26    _ExtraFields_Allocation,
27    _ExtraFields_TorchOp,
28    _ProfilerEvent,
29    _TensorMetadata,
30    RecordScope,
31)
32from torch._utils import _element_size
33from torch.profiler import _utils
34
35
36KeyAndID = Tuple["Key", int]
37TensorAndID = Tuple["TensorKey", int]
38
39log = logging.getLogger(__name__)
40
41
42class Category(enum.Enum):
43    INPUT = enum.auto()
44    TEMPORARY = enum.auto()
45    ACTIVATION = enum.auto()
46    GRADIENT = enum.auto()
47    AUTOGRAD_DETAIL = enum.auto()
48    PARAMETER = enum.auto()
49    OPTIMIZER_STATE = enum.auto()
50
51
52_CATEGORY_TO_COLORS = {
53    Category.PARAMETER: "darkgreen",
54    Category.OPTIMIZER_STATE: "goldenrod",
55    Category.INPUT: "black",
56    Category.TEMPORARY: "mediumpurple",
57    Category.ACTIVATION: "red",
58    Category.GRADIENT: "mediumblue",
59    Category.AUTOGRAD_DETAIL: "royalblue",
60    None: "grey",
61}
62
63_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)}
64
65
66class Action(enum.Enum):
67    PREEXISTING = enum.auto()
68    CREATE = enum.auto()
69    INCREMENT_VERSION = enum.auto()
70    DESTROY = enum.auto()
71
72
73_ACTION_TO_INDEX = {i: i.value for i in Action}
74
75
76@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
77class Key:
78    device: torch.device
79
80
81@dataclasses.dataclass
82class _Storage:
83    """Bundle storage pointer and id.
84
85    All profiling logic should use `allocation_id`; however, it is useful to
86    print storage pointers for debugging, and unit tests sometimes look up
87    values using the storage data pointer of a live Tensor."""
88
89    ptr: int
90    allocation_id: int
91
92    def __repr__(self) -> str:
93        return f"{hex(self.ptr):>18} ({self.allocation_id})"
94
95    def __eq__(self, other: object) -> bool:
96        return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
97
98    def __hash__(self) -> int:
99        return hash(self.allocation_id)
100
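# Illustrative sketch (not used anywhere in this module): equality and hashing
# for `_Storage` are keyed solely on `allocation_id`, so two records that refer
# to the same allocation compare equal even if their raw pointers differ.
#
#   >>> a = _Storage(ptr=0x1000, allocation_id=7)
#   >>> b = _Storage(ptr=0x2000, allocation_id=7)
#   >>> a == b, hash(a) == hash(b)
#   (True, True)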
101
102@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
103class TensorKey(Key):
104    """Hashable identifier for a storage which has been assigned an ID.
105
106    A detailed description of Tensor IDs and why they are needed is given in
107    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
108    summarize, multiple Storage buffers can map to the same logical Tensor.
109    This dataclass is used to refer to a concrete in-memory StorageImpl of
110    a Tensor.
111    """
112
113    id: int
114    storage: _Storage
115
116    def __repr__(self) -> str:
117        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"
118
119    def __lt__(self, other: "TensorKey") -> bool:
120        return self._as_sortable < other._as_sortable
121
122    @staticmethod
123    def _make(
124        tensor_id: Optional[int],
125        storage_ptr: Optional[int],
126        allocation_id: Optional[int],
127        device: torch.device,
128    ) -> Optional["TensorKey"]:
129        if (
130            tensor_id is not None
131            and storage_ptr is not None
132            and allocation_id is not None
133        ):
134            return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id))
135        return None
136
137    @classmethod
138    def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
139        return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)
140
141    @classmethod
142    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
143        if t is not None:
144            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
145        return None
146
147    @property
148    def _as_sortable(self) -> Tuple[int, int, str, int]:
149        return self.id, self.storage.allocation_id, self.device.type, self.device.index
150
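# Illustrative sketch with made-up IDs: `TensorKey._make` returns None if any
# field is missing; otherwise it bundles the device, Tensor ID, and concrete
# storage into a hashable, sortable key.
#
#   >>> cpu = torch.device("cpu")
#   >>> TensorKey._make(None, 0x1000, 1, cpu) is None
#   True
#   >>> key = TensorKey._make(2, 0x1000, 1, cpu)
#   >>> key.id, key.storage.allocation_id
#   (2, 1)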
151
152def _extract_parameters_and_gradients(
153    node: _ProfilerEvent,
154) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
155    children = node.children
156
157    # AccumulateGrad is used in the Autograd engine to handle gradient updates.
158    # There are two possible cases:
159    # 1) This is a newly created gradient Tensor. In that case there is nothing
160    #    to accumulate, so autograd simply detaches the Tensor.
161    #
162    # 2) There is a preexisting gradient Tensor and we need to add the newly
163    #    computed update. This is done with an in-place add (aten::add_) op.
164    #    (The underscore suffix denotes "in-place".)
165    if (
166        node.typed[0] == _EventType.TorchOp
167        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
168        # TODO(robieta): Move away from load bearing names
169        and node.name == "torch::autograd::AccumulateGrad"
170        and children
171        and children[0].typed[0] == _EventType.TorchOp
172        and children[0].name in ("aten::detach", "aten::add_")
173        and children[0].typed[1].inputs
174        and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
175    ):
176        yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0])
177
178    # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`.
179    # NOTE: The values captured by the python tracer are cached; they can be
180    #       used to build up labels but do not imply that a Tensor was live at
181    #       a particular time.
182    elif node.typed[0] == _EventType.PyCall:
183        typed_fields = node.typed[1]
184        assert typed_fields.module is None or typed_fields.optimizer is None
185        if typed_fields.module is not None:
186            for _, p, p_grad in typed_fields.module.parameters:
187                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)
188
189        if typed_fields.optimizer is not None:
190            for p, p_grad, _ in typed_fields.optimizer.parameters:
191                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)
192
193
194def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
195    for p, p_grad in _extract_parameters_and_gradients(node):
196        if p is not None:
197            yield p
198
199
200def extract_gradients(
201    node: _ProfilerEvent,
202) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
203    for p, p_grad in _extract_parameters_and_gradients(node):
204        if p_grad is not None:
205            yield p, p_grad
206
207
208def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
209    scopes = []
210    while event:
211        if event.typed[0] == _EventType.TorchOp:
212            scopes.append(event.typed[1].scope)
213        event = event.parent
214    return tuple(scopes)
215
216
217class SchemaMatcher:
218    """Lookup operator schema based on profiled name.
219
220    When profiling we record the operator's name but not the schema. However
221    some analysis requires that information. Fortunately we can look up
222    registered schema from the recorded name. We do not, however, record the
223    overload and so we must compare the profiled arguments with all overloads
224    to determine viable matches.
225
226    Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
227    this code will be obsolete.
228    """
229
230    @classmethod
231    def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
232        """Determine which inputs may have been mutated based on the function schema.
233
234        Note that we don't need to resolve down to a single schema to perform
235        this analysis. An input is mutable if it is mutable in any overload. In
236        practice, however, it is overwhelmingly common to match a single
237        overload. If we cannot find any valid schema then we must be
238        conservative and assume all inputs are mutable.
239        """
240        mutable: Optional[List[bool]] = None
241        for schema in cls.match_schemas(t):
242            mutable = mutable or [False for _ in schema.arguments]
243            for i, arg in enumerate(schema.arguments):
244                mutable[i] |= getattr(arg.alias_info, "is_write", False)
245
246        return tuple(mutable or (None for _ in t.inputs))
247
248    @classmethod
249    def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
250        signature = tuple(
251            # Tensor
252            TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
253            #
254            # TensorList
255            else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
256            #
257            # Scalar and uncaptured inputs.
258            else i
259            for i in t.inputs
260        )
261
262        def matches(schema) -> bool:
263            return len(schema.arguments) == len(signature) and all(
264                cls._types_match(observed, schema_arg.type)
265                for observed, schema_arg in zip(signature, schema.arguments)
266            )
267
268        return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))
269
270    @classmethod
271    def _types_match(cls, observed, schema_type) -> bool:
272        if isinstance(schema_type, torch._C.OptionalType):
273            schema_type = schema_type.getElementType()
274            return observed is None or cls._types_match(observed, schema_type)
275
276        if isinstance(schema_type, torch._C.AnyType):
277            return True
278
279        if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
280            return isinstance(observed, list) and all(
281                isinstance(i, TensorKey) for i in observed
282            )
283
284        type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
285            (torch._C.TensorType, TensorKey),
286            (torch._C.NoneType, type(None)),
287            (torch._C.BoolType, bool),
288            (torch._C.IntType, int),
289            (torch._C.FloatType, float),
290            (torch._C.ComplexType, complex),
291            (torch._C.NumberType, (bool, int, float, complex)),
292        )
293
294        for jit_type, py_types in type_map:
295            if isinstance(schema_type, jit_type):
296                return isinstance(observed, py_types)
297
298        # Profiler only records a subset of possible argument types. If we
299        # reach this point then the schema must call for a type that profiler
300        # does not record. Thus, the schema can only be a match if `observed`
301        # is also None.
302        return observed is None
303
304    @staticmethod
305    def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
306        # TODO(robieta):
307        #   _jit_get_schemas_for_operator is quite expensive. (~100us / call)
308        #   Consider adding `functools.lru_cache` if that becomes an issue.
309
310        try:
311            # Schema lookup will throw if `name` is malformed. (For example,
312            # schemas must be namespaced and schema lookup will fail if name
313            # does not include "::".) We simply catch the exception and return
314            # `None` to denote that `name` cannot be an operator name.
315            #
316            # Note that record_function annotations also go through this path,
317            # so it is expected that some names will not correspond to PyTorch
318            # operators.
319            if "::" not in name:
320                return None
321            return tuple(torch._C._jit_get_schemas_for_operator(name))
322        except RuntimeError:
323            return None
324
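# Illustrative sketch: schema lookup goes through the registered JIT operator
# schemas, so a namespaced op name resolves to one or more overloads while a
# plain `record_function` label (no "::") resolves to None.
#
#   >>> schemas = SchemaMatcher.lookup_schemas("aten::add")
#   >>> isinstance(schemas, tuple) and len(schemas) > 0
#   True
#   >>> SchemaMatcher.lookup_schemas("My annotation") is None
#   True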
325
326class OpTree:
327    def __init__(self, result: _ProfilerResult) -> None:
328        self._root_nodes = result.experimental_event_tree()
329        self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns))
330
331    def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
332        yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs)
333
334    @property
335    def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
336        return self._sorted_nodes
337
338
339class SizeMap:
340    def __init__(self, op_tree: OpTree) -> None:
341        self._values: Dict[TensorKey, int] = {}
342
343        for node in op_tree.sorted_nodes:
344            if node.typed[0] == _EventType.TorchOp:
345                for t in self._flat_tensor_inputs(node.typed[1]):
346                    self._update_values(t)
347
348            elif node.typed[0] == _EventType.PyCall:
349                typed_fields = node.typed[1]
350                assert typed_fields.module is None or typed_fields.optimizer is None
351                if typed_fields.module is not None:
352                    for _, p, p_grad in typed_fields.module.parameters:
353                        self._update_values(p)
354                        self._update_values(p_grad)
355
356                if typed_fields.optimizer is not None:
357                    for p, p_grad, state in typed_fields.optimizer.parameters:
358                        self._update_values(p)
359                        self._update_values(p_grad)
360                        for _, t in state:
361                            self._update_values(t)
362
363        allocations: Dict[TensorKey, int] = {}
364        for node in op_tree.sorted_nodes:
365            if node.typed[0] == _EventType.Allocation:
366                alloc_fields = node.typed[1]
367                key = TensorKey.from_allocation(alloc_fields)
368                if key:
369                    new_size = abs(alloc_fields.alloc_size)
370                    prior_size = allocations.setdefault(key, new_size)
371
372                    # It is possible to resize Storage in PyTorch; however, we
373                    # key on the data pointer, so most resizes will be treated as a
374                    # change in storage. The one corner case that cannot be
375                    # handled is `realloc` which successfully resizes the
376                    # storage. At time of writing this is not done anywhere in
377                    # the core PyTorch codebase.
378                    if prior_size != new_size:
379                        delta = f"{prior_size} vs. {new_size}"
380                        log.warning("Mismatch between allocation and free: %s", delta)
381
382        self._values.update(allocations)
383
384    def _update_values(self, t: Optional[_TensorMetadata]) -> None:
385        key = TensorKey.from_tensor(t)
386        if key is not None and t is not None and t.layout == torch.strided:
387            # Scalars are represented as zero dim Tensors
388            n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))
389
390            num_bytes = n * _element_size(t.dtype)
391            assert num_bytes >= 0, f"{num_bytes}"
392            self._values[key] = max(self._values.get(key, 0), num_bytes)
393
394    @staticmethod
395    def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
396        for i in op.inputs:
397            if isinstance(i, _TensorMetadata):
398                yield i
399            elif isinstance(i, list):
400                yield from i
401
402    def __getitem__(self, key: TensorKey):
403        return self._values[key]
404
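# Worked example of the size computation in `_update_values` (hypothetical
# metadata): the byte count is estimated as max(size_i * stride_i) elements
# times the element size, which also covers non-contiguous Tensors.
#
#   sizes=[2, 3], strides=[3, 1], dtype=float32
#       n = max(2 * 3, 3 * 1) = 6 elements -> 6 * 4 = 24 bytes
#   sizes=[2, 3], strides=[4, 1], dtype=float32   (padded rows)
#       n = max(2 * 4, 3 * 1) = 8 elements -> 8 * 4 = 32 bytes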
405
406@dataclasses.dataclass()
407class DataFlowEdge:
408    input_version: Optional[int] = None
409    mutated: Optional[bool] = False
410
411    @property
412    def is_allocation(self) -> bool:
413        return self.input_version is None
414
415    @property
416    def is_deletion(self) -> bool:
417        return self.mutated is None
418
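# Illustrative summary of how the two optional fields encode the edge type:
#
#   input_version=None, mutated=False -> allocation (new Tensor output)
#   input_version=None, mutated=None  -> allocated and freed within the node
#                                        (is_allocation and is_deletion)
#   input_version=v,    mutated=False -> read-only input at version v
#   input_version=v,    mutated=True  -> input mutated in place (version bump)
#   input_version=v,    mutated=None  -> input deleted by this node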
419
420class DataFlowNode:
421    def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
422        self._event = event
423        self._graph = graph
424        self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()
425
426        for key, edge in self._edges.items():
427            if edge.mutated and not edge.is_allocation:
428                self._graph.bump(key)
429
430        # Make sure the version bumping behavior matches what we expect.
431        versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
432        assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"
433
434    def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
435        subtree = tuple(_utils.traverse_dfs([self._event]))
436
437        # Start by populating edges from op inputs and outputs.
438        mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
439        for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
440            for op_input, mutable in zip(
441                op.inputs, SchemaMatcher.inputs_are_mutable(op)
442            ):
443                # Tensor
444                if isinstance(op_input, _TensorMetadata):
445                    key = TensorKey.from_tensor(op_input)
446                    mutable_by_key.setdefault(key, set()).add(mutable)
447
448                # TensorList
449                elif isinstance(op_input, list):
450                    for op_input_i in op_input:
451                        key = TensorKey.from_tensor(op_input_i)
452                        mutable_by_key.setdefault(key, set()).add(mutable)
453
454        edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
455        edges = collections.defaultdict(DataFlowEdge)
456        for key, mutable_set in mutable_by_key.items():
457            if key is not None:
458                edges[key].input_version = self._graph.lookup(key) if key else -1
459
460                # We consider an input to be mutated if we encounter a schema where
461                # it is a mutable argument OR if the case is ambiguous (i.e. we never
462                # explicitly see it in any schema).
463                mutated = (True in mutable_set) or (tuple(mutable_set) == (None,))
464                edges[key].mutated = mutated
465
466        # Then handle deletions. Note that deleting a Tensor implicitly adds
467        # it as an input edge.
468        for i in subtree:
469            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0:
470                key = TensorKey.from_allocation(i.typed[1])
471                edge = edges[key]
472                assert key is None or edge.mutated is not None, f"Double delete: {key}"
473                edge.mutated = None
474                edge.input_version = self._graph.lookup(key) if key else -1
475
476        # And finally handle allocations. This step must be last, because the
477        # previous two steps optimistically add input edges.
478        for i in subtree:
479            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0:
480                edges[TensorKey.from_allocation(i.typed[1])].input_version = None
481
482        # We don't need to sort the inputs, but it makes debugging and unit tests nicer.
483        return dict(sorted((k, v) for k, v in edges.items() if k is not None))
484
485    @property
486    def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]:
487        return {
488            # MyPy can't see through `is_allocation` to know that
489            # `v.input_version` is not None.
490            k: (bool(v.mutated), cast(int, v.input_version))
491            for k, v in self._edges.items()
492            if not v.is_allocation
493        }
494
495    @property
496    def outputs(self) -> Dict[TensorKey, int]:
497        return {
498            k: 0 if v.input_version is None else v.input_version + 1
499            for k, v in self._edges.items()
500            if (v.is_allocation and not v.is_deletion) or v.mutated
501        }
502
503    @property
504    def intermediates(self) -> Tuple[TensorKey, ...]:
505        return tuple(
506            k for k, v in self._edges.items() if v.is_allocation and v.is_deletion
507        )
508
509    @property
510    def start_time(self) -> int:
511        return self._event.start_time_ns
512
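# Illustrative sketch of the resulting views for a node wrapping an in-place
# op such as `aten::add_(param, grad)` (hypothetical keys and versions):
#
#   inputs        {param: (True, 3), grad: (False, 0)}   # (mutated, version)
#   outputs       {param: 4}                             # version after the bump
#   intermediates ()                                      # nothing allocated and freed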
513
514class DataFlowGraph:
515    def __init__(self, op_tree: OpTree) -> None:
516        self._op_tree = op_tree
517        self._leaf_events = self._extract_leaf_events(op_tree)
518        self._active_version: Dict[TensorKey, Optional[int]] = {}
519        self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
520        self._flow_nodes.sort(key=lambda x: x.start_time)
521        self.validate()
522
523    @property
524    def flow_nodes(self) -> Tuple[DataFlowNode, ...]:
525        return tuple(self._flow_nodes)
526
527    def validate(self):
528        # Check that each (Tensor, version) pair has a unique creation node
529        outputs: Set[Tuple[TensorKey, int]] = set()
530        for node in self.flow_nodes:
531            node_outputs = set(node.outputs.items())
532            duplicates = outputs & node_outputs
533            assert not duplicates, f"{node._event.name} {node._edges} {duplicates}"
534            outputs |= node_outputs
535
536        # And check that `self._flow_nodes` forms a valid topologically sorted DAG.
537        tensor_versions: Dict[TensorKey, int] = {}
538        for node in self.flow_nodes:
539            for key, (_, version) in node.inputs.items():
540                expected = tensor_versions.get(key, 0)
541                assert expected == version, (expected, version)
542
543            for key, version in node.outputs.items():
544                prior_version = tensor_versions.get(key, version)
545                assert version >= prior_version, (version, prior_version)
546                tensor_versions[key] = version
547
548    @property
549    def leaf_events(self) -> Tuple[_ProfilerEvent, ...]:
550        return self._leaf_events
551
552    @staticmethod
553    def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]:
554        """Partially traverse the op tree and extract top level ops.
555
556        Consider the following code:
557        ```
558        with record_function("My annotation"):
559            x.zero_()
560            y.zero_()
561        ```
562
563        The op tree (assuming no Autograd) will look like:
564          <Python context>
565            TorchOp: "My annotation"
566              TorchOp: zero_
567                TorchOp: fill_
568              TorchOp: zero_
569                TorchOp: fill_
570
571        The recursive structure of operator calls makes data flow unwieldy.
572        In order to simplify analysis we would like to select the highest level
573        ops to represent in the graph. In this case those are the `zero_` ops;
574        the fact that `fill_` is called is an implementation detail. We also
575        do not want to group everything under "My annotation" as this could
576        create overly coarse bundles and lose critical semantics.
577
578        To address this issue we walk over the graph and select the topmost
579        torch ops ** which match at least one operator schema **. These form
580        the leaves of the first pass through the op tree. (As well as any
581        allocations or frees which are not part of a kernel.) These events
582        form the logical nodes in our data flow graph.
583        """
584
585        leaf_events: List[_ProfilerEvent] = []
586
587        def leaf_op(e: _ProfilerEvent) -> bool:
588            return e.typed[0] == _EventType.TorchOp and (
589                e.typed[1].scope == RecordScope.BACKWARD_FUNCTION
590                or bool(SchemaMatcher.match_schemas(e.typed[1]))
591            )
592
593        def children_fn(e: _ProfilerEvent):
594            if leaf_op(e) or e.tag == _EventType.Allocation:
595                leaf_events.append(e)
596                return []
597
598            return e.children
599
600        for _ in op_tree.dfs(children_fn=children_fn):
601            pass
602
603        return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns))
604
605    def lookup(self, key: TensorKey) -> int:
606        version = self._active_version.setdefault(key, 0)
607        assert version is not None
608        return version
609
610    def bump(self, key: TensorKey) -> None:
611        prior_version = self._active_version.get(key, None)
612        assert prior_version is not None
613        self._active_version[key] = prior_version + 1
614
615    def delete(self, key: TensorKey) -> None:
616        assert self._active_version.setdefault(key, 0) is not None
617        self._active_version[key] = None
618
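# Illustrative sketch of the version bookkeeping used while building nodes
# (hypothetical key `k`): `lookup` lazily starts a Tensor at version 0, `bump`
# records an in-place mutation, and `delete` retires the Tensor.
#
#   >>> graph.lookup(k)   # first sighting -> 0
#   >>> graph.bump(k)     # e.g. an in-place op ran
#   >>> graph.lookup(k)   # -> 1
#   >>> graph.delete(k)   # storage freed; a later lookup would now assert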
619
620@dataclasses.dataclass
621class CategoryElement:
622    by_id: Optional[Category] = None
623    by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
624    by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
625
626    # Used by unit tests to check internals. (And consequently by
627    # MemoryProfile.lookup) This should not be used in any other capacity.
628    _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)
629
630
631@dataclasses.dataclass
632class CategoryDict:
633    _values: DefaultDict[int, CategoryElement] = dataclasses.field(
634        default_factory=lambda: collections.defaultdict(CategoryElement)
635    )
636
637    def set_by_id(self, key: TensorKey, category: Category) -> None:
638        self._values[key.id].by_id = category
639        self._values[key.id]._by_id_keyset.add(key)
640
641    def set_by_key(self, key: TensorKey, category: Category) -> None:
642        self._values[key.id].by_key[key] = category
643
644    def set_by_version(self, key: TensorKey, version: int, category: Category) -> None:
645        self._values[key.id].by_version[(key, version)] = category
646
647    def setdefault_by_version(
648        self, key: TensorKey, version: int, category: Category
649    ) -> None:
650        self._values[key.id].by_version.setdefault((key, version), category)
651
652    def get(self, key: Key, version: int) -> Optional[Category]:
653        if isinstance(key, Key) and not isinstance(key, TensorKey):
654            return None
655        element = self._values[key.id]
656        return (
657            element.by_id
658            or element.by_key.get(key, None)
659            or element.by_version.get((key, version), None)
660        )
661
662
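# Illustrative precedence for `CategoryDict.get` (hypothetical TensorKey `key`):
# an ID-level assignment wins over a key-level one, which wins over a
# (key, version)-level one, because `get` checks `by_id`, then `by_key`, then
# `by_version`.
#
#   >>> categories = CategoryDict()
#   >>> categories.set_by_version(key, 0, Category.ACTIVATION)
#   >>> categories.set_by_key(key, Category.INPUT)
#   >>> categories.get(key, 0) == Category.INPUT
#   True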
663class MemoryProfile:
664    def __init__(self, result: _ProfilerResult) -> None:
665        self._op_tree = OpTree(result)
666        self._data_flow_graph = DataFlowGraph(self._op_tree)
667        self._size_map = SizeMap(self._op_tree)
668        self._categories = CategoryDict()
669
670        self._set_gradients_and_temporaries()
671        self._set_parameters_using_python_tracer()
672        self._set_inputs()
673        self._set_parameters_using_data_flow()
674        self._set_activations()
675        self._set_optimizer_state()
676        self._set_autograd_detail()
677
678    @property
679    def timeline(self) -> Tuple[Tuple[int, Action, KeyAndID, int], ...]:
680        output: List[Tuple[int, Action, KeyAndID, int]] = []
681        allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
682        live_unknown: Dict[Tuple[int, torch.device], Literal[True]] = {}
683        for event in self._op_tree.dfs():
684            if event.typed[0] == _EventType.Allocation:
685                alloc_fields = event.typed[1]
686                alloc_size = alloc_fields.alloc_size
687                is_allocation = alloc_size > 0
688                t = event.start_time_ns
689
690                tkey = TensorKey.from_allocation(alloc_fields)
691                if tkey is not None:
692                    allocation_times[(tkey, is_allocation)] = t
693
694                else:
695                    key = Key(alloc_fields.device)
696                    ptr_and_device = (alloc_fields.ptr, key.device)
697                    if is_allocation:
698                        if ptr_and_device in live_unknown:
699                            output.append(
700                                (t, Action.INCREMENT_VERSION, (key, 0), alloc_size)
701                            )
702                        else:
703                            live_unknown[ptr_and_device] = True
704                            output.append((t, Action.CREATE, (key, 0), alloc_size))
705                    else:
706                        output.append((t, Action.DESTROY, (key, 0), -alloc_size))
707                        if not live_unknown.pop(ptr_and_device, False):
708                            output.append(
709                                (-1, Action.PREEXISTING, (key, 0), -alloc_size)
710                            )
711
712        snapshot = self._category_snapshot()
713        last_version = dict(sorted(snapshot.keys()))
714
715        events: List[Tuple[int, Action, TensorAndID]] = [
716            (-1, Action.PREEXISTING, (key, version))
717            for key, version in snapshot.keys()
718            if (key, True) not in allocation_times and version == 0
719        ]
720
721        for node in self._data_flow_graph.flow_nodes:
722            for key, edge in node._edges.items():
723                if edge.is_allocation:
724                    t = allocation_times[(key, True)]
725                    events.append((t, Action.CREATE, (key, 0)))
726
727                elif edge.mutated:
728                    t = node._event.start_time_ns
729                    version = edge.input_version
730                    assert version is not None
731                    events.append((t, Action.INCREMENT_VERSION, (key, version)))
732
733                if edge.is_deletion:
734                    t = allocation_times[(key, False)]
735                    events.append((t, Action.DESTROY, (key, last_version[key])))
736
737        output.extend(
738            (time, action, (key, version), self._size_map[key])
739            for time, action, (key, version) in events
740        )
741
742        output.sort(key=lambda x: (x[0], x[1].value))
743        return tuple(output)
744
745    def _is_gradient(self, *args, **kwargs) -> bool:
746        return self._categories.get(*args, **kwargs) == Category.GRADIENT
747
748    def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
749        all_tensor_versions: Set[TensorAndID] = set()
750
751        for node in self._data_flow_graph.flow_nodes:
752            all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
753            all_tensor_versions.update((key, 0) for key in node.intermediates)
754            all_tensor_versions.update(node.outputs.items())
755
756        for i in self._categories._values.values():
757            all_tensor_versions.update((key, 0) for key in i._by_id_keyset)
758
759        return {
760            (key, version): self._categories.get(key, version)
761            for key, version in sorted(all_tensor_versions)
762        }
763
764    def _any_version_depends_on_gradient(self) -> Set[int]:
765        """Extract IDs of Tensors which depend or will depend on a gradient.
766
767        Note that this weakened definition of "depends" requires us to loop
768        over the data flow graph multiple times because it allows dependency
769        information to flow backward through edges and removes the guarantee
770        that nodes are topologically sorted. (Or indeed, even that a valid
771        topological order exists.) Put another way, we have converted an
772        acyclic data flow graph into a cyclic graph and we are attempting to
773        partition cycles involving a gradient from the rest of the graph.
774        """
775        depends_on_gradient: Set[int] = set()
776        while True:
777            start_size = len(depends_on_gradient)
778            for node in self._data_flow_graph.flow_nodes:
779                ids = tuple(
780                    key.id
781                    for key, (_, version) in node.inputs.items()
782                    if self._categories.get(key, version)
783                    in (Category.GRADIENT, Category.PARAMETER)
784                    or key.id in depends_on_gradient
785                )
786
787                if ids:
788                    depends_on_gradient.update(ids)
789                    depends_on_gradient.update(key.id for key in node.outputs)
790
791            # We are guaranteed to exit because there is a finite set of
792            # TensorAndID pairs. In practice we do not expect to loop more than
793            # three times: once to identify the core parameter update loop,
794            # once to fold the first step into that loop, and a third time
795            # where no new elements are added.
796            if len(depends_on_gradient) == start_size:
797                return depends_on_gradient
798
799    def _set_gradients_and_temporaries(self) -> None:
800        """Mark Tensors which are unambiguous and simple to reason about."""
801
802        # Gradients are straightforward to detect. We directly check the
803        # `.grad` property in the Python tracer, and we can detect any new
804        # gradient Tensors from `AccumulateGrad` ops.
805        for event in self._op_tree.dfs():
806            for _, p_grad in extract_gradients(event):
807                self._categories.set_by_id(p_grad, Category.GRADIENT)
808
809        # Similarly, temporary Tensors are easy to identify and are useful to
810        # flag since they can make memory use "spikier" than one would
811        # otherwise expect.
812        for node in self._data_flow_graph.flow_nodes:
813            for i in node.intermediates:
814                self._categories.set_by_key(i, Category.TEMPORARY)
815
816    def _set_parameters_using_python_tracer(self) -> None:
817        for event in self._op_tree.dfs():
818            for p in extract_parameters(event):
819                if p is not None:
820                    self._categories.set_by_id(p, Category.PARAMETER)
821
822    def _set_inputs(self) -> None:
823        """Mark inputs based on which Tensors are updated using gradients.
824
825        The process for differentiating between inputs and activations is more
826        involved. Most Tensors in a training loop depend on at least one
827        gradient: parameters depend on them through updates, and activations
828        and optimizer state depend on them transitively through parameters.
829        Critically, we do not need to know which Tensors are parameters to
830        apply this method; we can simply walk the data flow graph to build the
831        set of all values which depend on a gradient and then obtain the set
832        of inputs from the complement of that set.
833
834        There is, however, one hiccup. The first time we see a parameter is
835        generally on the forward pass of the first step. We know from
836        inspection of the data flow graph that v1 of that Tensor depends on
837        a gradient (provided we profile an optimizer step), but not v0. To
838        address this problem we weaken the definition of "depends on a
839        gradient" to "any version of this Tensor depends on a gradient",
840        which in turn strengthens the criteria for the input set enough to
841        filter the activations in the forward pass of the first step."""
842
843        # All of this analysis is predicated on using at least one training
844        # step (or parameters from the python tracer) to partition the graph.
845        # Absent that we cannot determine which Tensors are inputs and which
846        # ones are part of the model.
847        depends_on_gradient = self._any_version_depends_on_gradient()
848
849        # We only want to annotate Tensors which actually contribute to the
850        # model calculation.
851        produces_gradient: Set[TensorAndID] = set()
852        for node in reversed(self._data_flow_graph.flow_nodes):
853            tensors = {(key, version) for key, (_, version) in node.inputs.items()}
854            tensors |= node.outputs.items()
855            if any(
856                self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER)
857                or i in produces_gradient
858                for i in tensors
859            ):
860                produces_gradient |= tensors
861
862        # Don't include Tensors created in the backward pass, as these are
863        # generally Autograd implementation details rather than proper inputs.
864        input_candidates = produces_gradient.copy()
865        for node in self._data_flow_graph.flow_nodes:
866            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
867                input_candidates -= set(node.outputs.items())
868
869        for key, version in input_candidates:
870            if key.id not in depends_on_gradient:
871                self._categories.setdefault_by_version(key, version, Category.INPUT)
872
873    def _set_parameters_using_data_flow(self) -> None:
874        """Deduce which Tensors are parameters.
875
876        Consider the following code for the step of SGD with momentum
877        (nesterov=False), where `d_p` is the gradient of `param` and `buf` is
878        the momentum buffer.
879        ```
880          buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
881          d_p = buf
882          param.add_(d_p, alpha=-lr)
883        ```
884        Both `param` and `buf` take a gradient and perform an in-place update.
885
886        The python tracer will inspect calls to `nn.Module.forward` and
887        `optim.Optimizer.step` to extract parameters and optimizer state
888        (which includes parameters), respectively, so this is generally a non-issue.
889
890        However as a fallback we can also exploit several properties of
891        parameters to distinguish them from other model state.
892
893        First, they are directly used in the forward pass. (At this point we
894        haven't established which parts of the graph correspond to the forward
895        pass but we can deduce enough to suffice.) Some mutable state such as
896        pass but we can deduce enough to suffice.) Some mutable state, such as
897        batch norm moving averages, also contributes to the forward pass, but
898
899        Second, a parameter is by definition used to compute at least one
900        gradient and depends on at least one gradient.
901        """
902        snapshot = self._category_snapshot()
903
904        # Determine which Tensors might be parameters based on forward pass
905        # data flow. Note that these are only candidates; we filter out nodes
906        # that we know are part of the backward pass, but that doesn't guarantee
907        # that the remaining nodes are part of the forward pass.
908        candidate_parameters: Set[TensorAndID] = set()
909        candidate_fwd_tensors: Set[TensorAndID] = {
910            i for i, category in snapshot.items() if category == Category.INPUT
911        }
912
913        for node in self._data_flow_graph.flow_nodes:
914            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
915            if (
916                # Don't check nodes in the backward pass.
917                RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
918                and not any(self._is_gradient(*i) for i in inputs)
919                and not any(self._is_gradient(*i) for i in node.outputs.items())
920                #
921                # and only check nodes which depend on an input.
922                and candidate_fwd_tensors.intersection(inputs)
923            ):
924                candidate_fwd_tensors |= node.outputs.items()
925                candidate_parameters |= inputs.difference(candidate_fwd_tensors)
926
927        # Require that each parameter eventually contributes to the value of a gradient
928        used_for_gradient: Set[TensorAndID] = set()
929        for node in reversed(self._data_flow_graph.flow_nodes):
930            if any(
931                self._is_gradient(*i) or i in used_for_gradient
932                for i in node.outputs.items()
933            ):
934                used_for_gradient.update(
935                    (key, version) for key, (_, version) in node.inputs.items()
936                )
937        candidate_parameters.intersection_update(used_for_gradient)
938
939        # and depends on a gradient.
940        parameter_keys = {key.id for key, _ in candidate_parameters}
941        parameter_keys &= self._any_version_depends_on_gradient()
942
943        for key, _ in snapshot.keys():
944            if key.id in parameter_keys:
945                self._categories.set_by_id(key, Category.PARAMETER)
946
947    def _set_activations(self) -> None:
948        """Flood the graph to identify activations."""
949
950        required = {Category.INPUT, Category.ACTIVATION}
951        also_allowed = {Category.PARAMETER, Category.TEMPORARY}
952        for node in self._data_flow_graph.flow_nodes:
953            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
954            input_categories = {self._categories.get(*i) for i in inputs}
955
956            if (
957                (input_categories & required)
958                and not (input_categories - (required | also_allowed))
959                #
960                # Stop filling when we reach the backward pass.
961                and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
962            ):
963                for i in node.outputs.items():
964                    self._categories.setdefault_by_version(*i, Category.ACTIVATION)
965
966    def _set_optimizer_state(self) -> None:
967        for event in self._op_tree.dfs():
968            if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer:
969                parameters = event.typed[1].optimizer.parameters
970                for _, t in it.chain(*[state for _, _, state in parameters]):
971                    key = TensorKey.from_tensor(t)
972                    if key is not None:
973                        self._categories.set_by_id(key, Category.OPTIMIZER_STATE)
974
975    def _set_autograd_detail(self):
976        prior = {None, Category.AUTOGRAD_DETAIL}
977        for node in self._data_flow_graph.flow_nodes:
978            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
979                for key, version in node.outputs.items():
980                    if version == 0 or self._categories.get(key, version - 1) in prior:
981                        self._categories.setdefault_by_version(
982                            key, version, Category.AUTOGRAD_DETAIL
983                        )
984
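# Illustrative sketch of how a `MemoryProfile` is typically obtained and
# inspected. The `_memory_profile()` accessor on `torch.profiler.profile` is
# assumed here; constructing `MemoryProfile` directly from a `_ProfilerResult`
# works the same way.
#
#   with torch.profiler.profile(
#       profile_memory=True, record_shapes=True, with_stack=True
#   ) as prof:
#       train_one_step()          # hypothetical user function
#   mem = prof._memory_profile()  # -> MemoryProfile
#   for t, action, (key, _), nbytes in mem.timeline[:10]:
#       print(t, action, key, nbytes)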
985
986class MemoryProfileTimeline:
987    def __init__(self, memory_profile):
988        """The minimum representation of the memory profile timeline
989        includes the memory timeline and categories. The timeline
990        consists of [timestamp, action, (TensorKey, version), numbytes]
991        elements, each denoting an action (pre-existing, create, destroy,
992        or increment_version) that occurred to a specific chunk of Tensor
993        memory. The categories help map each (TensorKey,
994        version) pair into a category."""
995        self.timeline = memory_profile.timeline
996        self.categories = memory_profile._categories
997
998    def _coalesce_timeline(self, device_str):
999        """Convert the memory timeline and categories into a memory plot
1000        consisting of timestamps and their respective sizes by category
1001        for a given device.
1002
1003        Input: device
1004        Output: [timestamps, sizes by category]
1005        """
1006        device = torch.device(device_str)
1007        times: List[int] = []
1008        sizes: List[List[int]] = []
1009
1010        def update(key, version, delta):
1011            category = (
1012                self.categories.get(key, version)
1013                if isinstance(key, TensorKey)
1014                else None
1015            )
1016            index = _CATEGORY_TO_INDEX[category] + 1
1017            sizes[-1][index] += int(delta)
1018
1019        t_min = -1
1020        for t, action, (key, version), numbytes in self.timeline:
1021            if key.device != device:
1022                continue
1023
1024            # Convert timestamps from ns to us, to match trace events.
1025            if t != -1:
1026                t = int(t / 1000)
1027
1028            # Save the smallest timestamp to populate pre-existing allocs.
1029            if t_min == -1 or (t < t_min and t > 0):
1030                t_min = t
1031
1032            # Handle timestep
1033            if len(times) == 0:
1034                times.append(t)
1035                sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX])
1036
1037            elif t != times[-1]:
1038                times.append(t)
1039                sizes.append(sizes[-1].copy())
1040
1041            # Handle memory and categories
1042            if action in (Action.PREEXISTING, Action.CREATE):
1043                update(key, version, numbytes)
1044
1045            elif action == Action.INCREMENT_VERSION:
1046                update(key, version, -numbytes)
1047                update(key, version + 1, numbytes)
1048
1049            elif action == Action.DESTROY:
1050                update(key, version, -numbytes)
1051
1052            else:
1053                raise ValueError(f"Unknown action: {action}")
1054
1055        times = [t_min if t < 0 else t for t in times]
1056        return times, sizes
1057
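    # Illustrative shape of the output above (hypothetical numbers): `times` is
    # a sorted list of microsecond timestamps and each row of `sizes` has one
    # leading baseline column (always 0) followed by one running total, in
    # bytes, per entry in `_CATEGORY_TO_INDEX`.
    #
    #   times = [1000, 1250, ...]
    #   sizes = [[0, 4096, 0, 512, 0, 0, 0, 0, 0],     # at t=1000
    #            [0, 4096, 0, 512, 0, 2048, 0, 0, 0],  # at t=1250
    #            ...]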
1058    def export_memory_timeline(self, path, device_str) -> None:
1059        """Saves the memory timeline as a JSON-formatted file of
1060        [times, sizes by category] at the given path for the given
1061        device."""
1062        times, sizes = self._coalesce_timeline(device_str)
1063        # TODO: Write a faster serialize (orjson not available in CI)
1064        import json
1065
1066        with open(path, "w") as f:
1067            json.dump([times, sizes], f)
1068
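    # Illustrative usage (an assumption about the public API, not defined in
    # this module): recent PyTorch releases expose a
    # `torch.profiler.profile.export_memory_timeline(path, device)` wrapper
    # that routes through this class, so the exporters here are rarely called
    # directly. A manual equivalent would look like:
    #
    #   timeline = MemoryProfileTimeline(prof._memory_profile())
    #   timeline.export_memory_timeline("mem.json", "cuda:0")
    #   timeline.export_memory_timeline_html("mem.html", "cuda:0")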
1069    def export_memory_timeline_raw(self, path, device_str) -> None:
1070        """Saves the memory timeline as raw memory event tuples in the
1071        form of (timestamp, action, numbytes, category)
1072        as a JSON formatted file to the given path for the given
1073        device."""
1074        device = torch.device(device_str)
1075        raw_events: List[Tuple[int, int, int, int]] = []
1076
1077        def get_category_index(key, version):
1078            category = (
1079                self.categories.get(key, version)
1080                if isinstance(key, TensorKey)
1081                else None
1082            )
1083            return _CATEGORY_TO_INDEX[category]
1084
1085        for t, action, (key, version), numbytes in self.timeline:
1086            if key.device != device:
1087                continue
1088
1089            if action in (Action.PREEXISTING, Action.CREATE):
1090                raw_events.append(
1091                    (
1092                        t,
1093                        _ACTION_TO_INDEX[action],
1094                        numbytes,
1095                        get_category_index(key, version),
1096                    )
1097                )
1098
1099            elif action == Action.INCREMENT_VERSION:
1100                raw_events.append(
1101                    (
1102                        t,
1103                        _ACTION_TO_INDEX[action],
1104                        -numbytes,
1105                        get_category_index(key, version),
1106                    )
1107                )
1108                raw_events.append(
1109                    (
1110                        t,
1111                        _ACTION_TO_INDEX[action],
1112                        numbytes,
1113                        get_category_index(key, version + 1),
1114                    )
1115                )
1116
1117            elif action == Action.DESTROY:
1118                raw_events.append(
1119                    (
1120                        t,
1121                        _ACTION_TO_INDEX[action],
1122                        -numbytes,
1123                        get_category_index(key, version),
1124                    )
1125                )
1126
1127            else:
1128                raise ValueError(f"Unknown action: {action}")
1129
1130        import json
1131
1132        with open(path, "w") as f:
1133            json.dump(raw_events, f)
1134
1135    def export_memory_timeline_html(
1136        self, path, device_str, figsize=(20, 12), title=None
1137    ) -> None:
1138        """Exports the memory timeline as an HTML file which contains
1139        the memory timeline plot embedded as a PNG file."""
1140        # Check if user has matplotlib installed, return gracefully if not.
1141        import importlib.util
1142
1143        matplotlib_spec = importlib.util.find_spec("matplotlib")
1144        if matplotlib_spec is None:
1145            print(
1146                "export_memory_timeline_html failed because matplotlib was not found."
1147            )
1148            return
1149
1150        from base64 import b64encode
1151        from os import remove
1152        from tempfile import NamedTemporaryFile
1153
1154        import matplotlib.pyplot as plt
1155        import numpy as np
1156
1157        mt = self._coalesce_timeline(device_str)
1158        times, sizes = np.array(mt[0]), np.array(mt[1])
1159        # For this timeline, start at 0 to match Chrome traces.
1160        t_min = min(times)
1161        times -= t_min
1162        stacked = np.cumsum(sizes, axis=1) / 1024**3
1163        device = torch.device(device_str)
1164        max_memory_allocated = torch.cuda.max_memory_allocated(device)
1165        max_memory_reserved = torch.cuda.max_memory_reserved(device)
1166
1167        # Plot memory timeline as stacked data
1168        fig = plt.figure(figsize=figsize, dpi=80)
1169        axes = fig.gca()
1170        for category, color in _CATEGORY_TO_COLORS.items():
1171            i = _CATEGORY_TO_INDEX[category]
1172            axes.fill_between(
1173                times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7
1174            )
1175        fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS])
1176        # Usually training steps are in magnitude of ms.
1177        axes.set_xlabel("Time (ms)")
1178        axes.set_ylabel("Memory (GB)")
1179        title = "\n\n".join(
1180            ([title] if title else [])
1181            + [
1182                f"Max memory allocated: {max_memory_allocated/(1024**3):.2f} GiB \n"
1183                f"Max memory reserved: {max_memory_reserved/(1024**3):.2f} GiB"
1184            ]
1185        )
1186        axes.set_title(title)
1187
1188        # Embed the memory timeline image into the HTML file
1189        tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False)
1190        tmpfile.close()
1191        fig.savefig(tmpfile.name, format="png")
1192
1193        with open(tmpfile.name, "rb") as tmp:
1194            encoded = b64encode(tmp.read()).decode("utf-8")
1195            html = f"""<html>
1196<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head>
1197<body>
1198  <img src='data:image/png;base64,{encoded}'>
1199</body>
1200</html>"""
1201
1202            with open(path, "w") as f:
1203                f.write(html)
1204        remove(tmpfile.name)
1205