# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import warnings
from typing import Callable, List, Optional

import torch
from executorch.exir.error import internal_assert
from executorch.exir.memory import alloc
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
    get_node_tensor_specs,
    greedy,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.tensor import ALIGNMENT
from torch.export.exported_program import ExportGraphSignature


class MemoryPlanningPass(PassBase):
    def __init__(
        self,
        memory_planning_algo: Callable[..., List[int]] = greedy,
        allow_lifetime_and_storage_overlap: bool = False,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""
        alloc_graph_input/alloc_graph_output yield four different combinations
        that control whether the memory planning algorithm allocates memory
        for the graph inputs and outputs. The default behavior is to allocate
        memory for both graph inputs and outputs.
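
        For example, an illustrative combination of the flags above:

            MemoryPlanningPass(alloc_graph_input=False)

        plans storage for intermediate and output tensors while leaving the
        graph inputs unplanned.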
        """
        self.memory_planning_algo = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alignment = alignment

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Pass for setting all of the alloc nodes' specs. These nodes are created
        in the ToOutVarPass but do not have a spec.

        TODO(shunting): we probably should set up the spec for the memory.alloc
          node in the ToOutVarPass
        """
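        # Illustrative shape of the graph this method expects (a sketch
        # inferred from the assertions below, not copied from ToOutVarPass):
        #   alloc_node = <call_function memory.alloc>      # no "spec" in meta
        #   out_var    = <call_function op.out>(..., out=alloc_node)
        # Each alloc node inherits the spec of the out-var node that fills it.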
        for subgm in graph_module.modules():
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            for node in subgm.graph.nodes:
                if _is_out_var_node(node):
                    out_arg_names = get_out_args_from_opoverload(node.target)
                    if len(out_arg_names) == 1:
                        out_alloc_node = node.kwargs[out_arg_names[0]]
                        out_alloc_node.meta["spec"] = node.meta["spec"]
                        continue
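                    # Multiple out kwargs: pair each alloc node with the tensor
                    # spec at the matching position.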
                    specs = get_node_tensor_specs(node)
                    for i, out_arg in enumerate(out_arg_names):
                        out_alloc_node = node.kwargs[out_arg]
                        if out_alloc_node is None:
                            warnings.warn(
                                f"Function {node.target}'s {out_arg} kwarg value is None",
                                stacklevel=1,
                            )
                            continue
                        internal_assert(
                            out_alloc_node.op == "call_function"
                            and out_alloc_node.target == alloc,
                            f"Out-var's node {out_alloc_node} has op {out_alloc_node.op} and target {out_alloc_node.target}",
                        )
                        internal_assert(
                            "spec" not in out_alloc_node.meta,
                            f"Out-var's allocation node {out_alloc_node} already has a spec assigned",
                        )
                        out_alloc_node.meta["spec"] = specs[i]

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
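        # PassBase entry point; delegates to run() without a graph signature.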
        return self.run(graph_module)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> PassResult:
        """
        A pass for memory planning. The actual algorithm used is selected by
        memory_planning_algo.
        """
        self._set_alloc_node_spec(graph_module)
        # TODO(shunting) if people have concerns about adding a field to
        # GraphModule directly, we should define a GraphModule subclass to
        # which we can add our customized fields. Using the graph_module object
        # to convey information across passes/stages is quite natural and
        # avoids yet another 'context' data structure to do the job.
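        # apply_algo runs the chosen algorithm over the graph and annotates it
        # in place; the List[int] it returns (buffer sizes, per the
        # memory_planning_algo signature above) is intentionally discarded.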
        _ = apply_algo(
            self.memory_planning_algo,
            graph_module,
            self.alignment,
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
        )

        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(
            graph_module,
            self.alloc_graph_input,
            self.alloc_graph_output,
            graph_signature,
        )

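        # Storage-reuse verification compares pairs of tensors, so it only
        # runs when DEBUG logging is enabled (presumably to avoid the cost on
        # regular runs).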
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            num_reuse_pairs = verifier.verify_storage_reuse(
                self.allow_lifetime_and_storage_overlap
            )
            logging.debug(
                f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pairs of tensors"
            )
        verifier.verify_graph_input_output()
        return PassResult(graph_module, True)
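
# Typical usage (an illustrative sketch; `graph_module` and `graph_signature`
# are assumed to come from an already exported and lowered program):
#
#   mem_pass = MemoryPlanningPass(alloc_graph_input=False)
#   result = mem_pass.run(graph_module, graph_signature)
#   planned_graph_module = result.graph_module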