# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import warnings
from typing import Callable, List, Optional

import torch
from executorch.exir.error import internal_assert
from executorch.exir.memory import alloc
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
    get_node_tensor_specs,
    greedy,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.tensor import ALIGNMENT
from torch.export.exported_program import ExportGraphSignature


class MemoryPlanningPass(PassBase):
    def __init__(
        self,
        memory_planning_algo: Callable[..., List[int]] = greedy,
        allow_lifetime_and_storage_overlap: bool = False,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""
        alloc_graph_input/alloc_graph_output form four combinations that
        control whether the memory planning algorithm allocates memory for the
        graph inputs/outputs. By default the algorithm allocates memory for
        both graph inputs and outputs.
        """
        self.memory_planning_algo = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alignment = alignment

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Pass for setting all of the alloc nodes' specs. These nodes are created
        in the ToOutVarPass but do not have a spec.

        TODO(shunting): we probably should set up the spec for memory.alloc nodes
        in the ToOutVarPass
        """
        for subgm in graph_module.modules():
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            for node in subgm.graph.nodes:
                if _is_out_var_node(node):
                    out_arg_names = get_out_args_from_opoverload(node.target)
                    if len(out_arg_names) == 1:
                        out_alloc_node = node.kwargs[out_arg_names[0]]
                        out_alloc_node.meta["spec"] = node.meta["spec"]
                        continue
                    specs = get_node_tensor_specs(node)
                    for i, out_arg in enumerate(out_arg_names):
                        out_alloc_node = node.kwargs[out_arg]
                        if out_alloc_node is None:
                            warnings.warn(
                                f"Function {node.target}'s {out_arg} kwarg value is None",
                                stacklevel=1,
                            )
                            continue
                        internal_assert(
                            out_alloc_node.op == "call_function"
                            and out_alloc_node.target == alloc,
                            f"Out-var's node {out_alloc_node} has op {out_alloc_node.op} and target {out_alloc_node.target}",
                        )
                        internal_assert(
                            "spec" not in out_alloc_node.meta,
                            f"Out-var's allocation node {out_alloc_node} already has a spec assigned",
                        )
                        out_alloc_node.meta["spec"] = specs[i]

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        return self.run(graph_module)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> PassResult:
        """
        A pass for memory planning. The actual algorithm used is selected by
        memory_planning_algo.
        """
        self._set_alloc_node_spec(graph_module)
        # TODO(shunting): if people have concerns about adding a field to GraphModule
        # directly, we should define a GraphModule subclass to which we can add our
        # customized fields. Using the graph_module object to convey information across
        # passes/stages is quite natural and avoids yet another 'context' data structure
        # to do the job.
        _ = apply_algo(
            self.memory_planning_algo,
            graph_module,
            self.alignment,
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
        )

        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(
            graph_module,
            self.alloc_graph_input,
            self.alloc_graph_output,
            graph_signature,
        )

        if logging.getLogger().isEnabledFor(logging.DEBUG):
            num_reuse_pairs = verifier.verify_storage_reuse(
                self.allow_lifetime_and_storage_overlap
            )
            logging.debug(
                f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pairs of tensors"
            )
        verifier.verify_graph_input_output()
        return PassResult(graph_module, True)
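

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; kept as comments so nothing runs at import
# time). It shows how MemoryPlanningPass is typically constructed and run
# against an exported program's graph module. `exported_program` below is a
# hypothetical placeholder for a torch.export.ExportedProgram produced earlier
# in the lowering pipeline; the export/to-edge steps that produce it are
# omitted here.
#
#   pass_instance = MemoryPlanningPass(
#       memory_planning_algo=greedy,   # default planner imported above
#       alloc_graph_input=True,        # plan storage for graph inputs
#       alloc_graph_output=True,       # plan storage for graph outputs
#       alignment=ALIGNMENT,
#   )
#   # `run` takes the graph signature (when available) so the planner and
#   # verifier can distinguish graph inputs/outputs from intermediate tensors.
#   result = pass_instance.run(
#       exported_program.graph_module,
#       exported_program.graph_signature,
#   )
#   planned_graph_module = result.graph_module  # PassResult carries the module
# -----------------------------------------------------------------------------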