# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from collections import OrderedDict
from typing import cast, Mapping, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.utils import _pytree as pytree


# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
# Propagating aten.full can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)


def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    if isinstance(arg, (tuple, list)):
        return all(is_const(x, exported_program, const_node_to_tensor) for x in arg)
    elif isinstance(arg, dict):
        return all(
            is_const(x, exported_program, const_node_to_tensor) for x in arg.values()
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return True
    elif not isinstance(arg, torch.fx.Node):
        return False
    return arg in const_node_to_tensor


def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    if isinstance(arg, (tuple, list)):
        return type(arg)(
            get_data(x, exported_program, const_node_to_tensor) for x in arg
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    elif arg in const_node_to_tensor:
        return const_node_to_tensor[arg]
    return None


def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Returns a dictionary of placeholder node -> constant tensor.
    """
    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if is_param(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_param(exported_program, node)
            )
        elif is_buffer(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_buffer(exported_program, node)
            )
        elif is_lifted_tensor_constant(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_lifted_tensor_constant(exported_program, node)
            )
    return const_node_to_tensor


def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node -> constant tensor.
    """
    # Initialize the dict with all constant placeholders.
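    # Params, buffers, and lifted tensor constants seed the map; `graph.nodes`
    # is kept in topological order by FX, so the single forward walk below can
    # fold every call_function whose inputs are already known constants.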
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    if custom_skip_targets is not None:
        all_skip_targets = custom_skip_targets
    else:
        # Default set of targets to skip.
        all_skip_targets = _DEFAULT_SKIP_TARGETS

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in all_skip_targets:
            continue

        if not is_const(
            node.args,
            exported_program,
            const_node_to_tensor,
        ) or not is_const(
            node.kwargs,
            exported_program,
            const_node_to_tensor,
        ):
            continue

        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )
        # Disable grad for constant propagation; otherwise the generated tensor
        # cannot be copied because of its grad_fn.
        with torch.no_grad():
            # Execute `node.target` to create the propagated constant tensor.
            prop_constant_tensor = node.target(*args_data, **kwargs_data)
        const_node_to_tensor[node] = prop_constant_tensor

    return const_node_to_tensor


def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]:
    """Returns the first user input node in the graph, or None if there is none."""
    first_user_input = None
    for node in exported_program.graph.nodes:
        if (
            node.op == "placeholder"
            and node.name in exported_program.graph_signature.user_inputs
        ):
            first_user_input = node
            break
    return first_user_input


def replace_with_constant_node(
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    first_user_input: Optional[torch.fx.Node],
    fake_mode,
    exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
    # Add `prop_constant_tensor` to the program's constants.
    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
    exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor

    # Insert a new placeholder node for the propagated constant tensor.
    with exported_program.graph.inserting_before(first_user_input):
        const_placeholder_node = exported_program.graph.placeholder(
            prop_constant_tensor_fqn
        )

    # Copy the metadata of the original node onto the new placeholder node.
    for k, v in node.meta.items():
        const_placeholder_node.meta[k] = v
    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
        prop_constant_tensor, static_shapes=True
    )
    const_placeholder_node.meta["val"].constant = prop_constant_tensor

    # Replace the original node with the new constant node.
    node.replace_all_uses_with(const_placeholder_node)
    exported_program.graph.erase_node(node)

    return const_placeholder_node, prop_constant_tensor_fqn


def get_fake_mode(exported_program: ExportedProgram):
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )
    assert fake_mode is not None
    return fake_mode


def erase_constant_node(
    exported_program: ExportedProgram,
    node: torch.fx.Node,
) -> None:
    # Remove the corresponding tensor from the param/constants dict.
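    # Parameters live in `state_dict`, lifted tensor constants in `constants`,
    # and buffers in either (persistent buffers in `state_dict`, non-persistent
    # ones in `constants`), hence the buffer branch pops from both.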
    signature = exported_program.graph_signature
    if name := signature.inputs_to_parameters.get(node.name, None):
        exported_program.state_dict.pop(name, None)
    elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
        exported_program.constants.pop(name, None)
    elif name := signature.inputs_to_buffers.get(node.name, None):
        exported_program.constants.pop(name, None)
        exported_program.state_dict.pop(name, None)

    # Remove the node from the graph.
    exported_program.graph.erase_node(node)


def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and
    returns a node.name -> InputSpec dict.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Iterate in reverse topological order so that a constant node's consumers
    # are removed from the graph before the node itself (`erase_node` requires
    # that a node have no remaining users).
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        if all(x in const_node_to_tensor for x in node.users):
            # All users of this constant node are also constant, so there is no
            # need to create a new constant node; erase it instead.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            continue

        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Create an input spec for the lifted constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict


def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    Performs constant propagation on an ExportedProgram with lifted parameters.
    Because parameters are lifted, they appear in the graph as `placeholder`
    nodes rather than `get_attr` nodes.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip
            during constant propagation.

    Returns:
        The modified ExportedProgram with constant propagation applied.
    """
    if not any(node.op == "placeholder" for node in exported_program.graph.nodes):
        return exported_program

    if any(
        node.target == torch.ops.higher_order.cond
        for node in exported_program.graph.nodes
    ):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Get the old input specs.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # Add the new constants to the input specs dict.
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Generate the new input specs.
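    # The signature's input specs must line up one-to-one, in graph order, with
    # the placeholder nodes, so rebuild the list by walking the placeholders.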
    new_input_specs = []
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue
        new_input_specs.append(name_to_spec_dict[node.name])
    exported_program.graph_signature.input_specs = new_input_specs

    # Clean up the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
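

# Minimal usage sketch, illustrative only: the demo module, its shapes, and the
# `to_edge` round-trip are assumptions made for this example, not part of the
# pass. After running the pass, `weight + 1` should be folded into a lifted
# `_prop_tensor_constant*` placeholder, while the matmul against `x` remains.
if __name__ == "__main__":
    from executorch.exir import to_edge
    from torch.export import export

    class _Demo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(4, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # `self.weight + 1` depends only on the lifted parameter, so it is
            # foldable; the matmul still runs at inference time.
            return x @ (self.weight + 1)

    edge_program = to_edge(export(_Demo(), (torch.randn(2, 4),))).exported_program()
    constant_prop_pass(edge_program)
    print(edge_program.graph)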